Chore: split component to transport
This commit is contained in:
223
transport/trojan/trojan.go
Normal file
223
transport/trojan/trojan.go
Normal file
@ -0,0 +1,223 @@
|
||||
package trojan
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/Dreamacro/clash/transport/socks5"
|
||||
)
|
||||
|
||||
const (
|
||||
// max packet length
|
||||
maxLength = 8192
|
||||
)
|
||||
|
||||
var (
|
||||
defaultALPN = []string{"h2", "http/1.1"}
|
||||
crlf = []byte{'\r', '\n'}
|
||||
|
||||
bufPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }}
|
||||
)
|
||||
|
||||
type Command = byte
|
||||
|
||||
var (
|
||||
CommandTCP byte = 1
|
||||
CommandUDP byte = 3
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
Password string
|
||||
ALPN []string
|
||||
ServerName string
|
||||
SkipCertVerify bool
|
||||
ClientSessionCache tls.ClientSessionCache
|
||||
}
|
||||
|
||||
type Trojan struct {
|
||||
option *Option
|
||||
hexPassword []byte
|
||||
}
|
||||
|
||||
func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) {
|
||||
alpn := defaultALPN
|
||||
if len(t.option.ALPN) != 0 {
|
||||
alpn = t.option.ALPN
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: alpn,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: t.option.SkipCertVerify,
|
||||
ServerName: t.option.ServerName,
|
||||
ClientSessionCache: t.option.ClientSessionCache,
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
defer bufPool.Put(buf)
|
||||
defer buf.Reset()
|
||||
|
||||
buf.Write(t.hexPassword)
|
||||
buf.Write(crlf)
|
||||
|
||||
buf.WriteByte(command)
|
||||
buf.Write(socks5Addr)
|
||||
buf.Write(crlf)
|
||||
|
||||
_, err := w.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
|
||||
return &PacketConn{
|
||||
Conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
defer bufPool.Put(buf)
|
||||
defer buf.Reset()
|
||||
|
||||
buf.Write(socks5Addr)
|
||||
binary.Write(buf, binary.BigEndian, uint16(len(payload)))
|
||||
buf.Write(crlf)
|
||||
buf.Write(payload)
|
||||
|
||||
return w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
|
||||
if len(payload) <= maxLength {
|
||||
return writePacket(w, socks5Addr, payload)
|
||||
}
|
||||
|
||||
offset := 0
|
||||
total := len(payload)
|
||||
for {
|
||||
cursor := offset + maxLength
|
||||
if cursor > total {
|
||||
cursor = total
|
||||
}
|
||||
|
||||
n, err := writePacket(w, socks5Addr, payload[offset:cursor])
|
||||
if err != nil {
|
||||
return offset + n, err
|
||||
}
|
||||
|
||||
offset = cursor
|
||||
if offset == total {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
|
||||
addr, err := socks5.ReadAddr(r, payload)
|
||||
if err != nil {
|
||||
return nil, 0, 0, errors.New("read addr error")
|
||||
}
|
||||
uAddr := addr.UDPAddr()
|
||||
|
||||
if _, err = io.ReadFull(r, payload[:2]); err != nil {
|
||||
return nil, 0, 0, errors.New("read length error")
|
||||
}
|
||||
|
||||
total := int(binary.BigEndian.Uint16(payload[:2]))
|
||||
if total > maxLength {
|
||||
return nil, 0, 0, errors.New("packet invalid")
|
||||
}
|
||||
|
||||
// read crlf
|
||||
if _, err = io.ReadFull(r, payload[:2]); err != nil {
|
||||
return nil, 0, 0, errors.New("read crlf error")
|
||||
}
|
||||
|
||||
length := len(payload)
|
||||
if total < length {
|
||||
length = total
|
||||
}
|
||||
|
||||
if _, err = io.ReadFull(r, payload[:length]); err != nil {
|
||||
return nil, 0, 0, errors.New("read packet error")
|
||||
}
|
||||
|
||||
return uAddr, length, total - length, nil
|
||||
}
|
||||
|
||||
func New(option *Option) *Trojan {
|
||||
return &Trojan{option, hexSha224([]byte(option.Password))}
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
net.Conn
|
||||
remain int
|
||||
rAddr net.Addr
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
|
||||
}
|
||||
|
||||
func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
pc.mux.Lock()
|
||||
defer pc.mux.Unlock()
|
||||
if pc.remain != 0 {
|
||||
length := len(b)
|
||||
if pc.remain < length {
|
||||
length = pc.remain
|
||||
}
|
||||
|
||||
n, err := pc.Conn.Read(b[:length])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
pc.remain -= n
|
||||
addr := pc.rAddr
|
||||
if pc.remain == 0 {
|
||||
pc.rAddr = nil
|
||||
}
|
||||
|
||||
return n, addr, nil
|
||||
}
|
||||
|
||||
addr, n, remain, err := ReadPacket(pc.Conn, b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if remain != 0 {
|
||||
pc.remain = remain
|
||||
pc.rAddr = addr
|
||||
}
|
||||
|
||||
return n, addr, nil
|
||||
}
|
||||
|
||||
func hexSha224(data []byte) []byte {
|
||||
buf := make([]byte, 56)
|
||||
hash := sha256.New224()
|
||||
hash.Write(data)
|
||||
hex.Encode(buf, hash.Sum(nil))
|
||||
return buf
|
||||
}
|
Reference in New Issue
Block a user