diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index eef05687..cb8f61c4 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -171,7 +171,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) { host, _, _ := net.SplitHostPort(v.addr) - if v.isXTLSEnabled() && !isH2 { + if v.isLegacyXTLSEnabled() && !isH2 { xtlsOpts := vless.XTLSConfig{ Host: host, SkipCertVerify: v.option.SkipCertVerify, @@ -206,8 +206,8 @@ func (v *Vless) streamTLSOrXTLSConn(conn net.Conn, isH2 bool) (net.Conn, error) return conn, nil } -func (v *Vless) isXTLSEnabled() bool { - return v.client.Addons != nil +func (v *Vless) isLegacyXTLSEnabled() bool { + return v.client.Addons != nil && v.client.Addons.Flow != vless.XRV } // DialContext implements C.ProxyAdapter diff --git a/common/buf/sing.go b/common/buf/sing.go index b5e015f5..ccd2d368 100644 --- a/common/buf/sing.go +++ b/common/buf/sing.go @@ -7,6 +7,7 @@ import ( type Buffer = buf.Buffer +var StackNew = buf.StackNew var StackNewSize = buf.StackNewSize var KeepAlive = common.KeepAlive diff --git a/transport/vless/conn.go b/transport/vless/conn.go index e063d465..2aef9f9b 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -1,12 +1,21 @@ package vless import ( + "bytes" + gotls "crypto/tls" "encoding/binary" "errors" "fmt" + tlsC "github.com/Dreamacro/clash/component/tls" + "github.com/Dreamacro/clash/log" + utls "github.com/refraction-networking/utls" + "github.com/sagernet/sing/common/network" "io" "net" + "reflect" "sync" + "time" + "unsafe" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" @@ -17,7 +26,9 @@ import ( ) type Conn struct { - N.ExtendedConn + network.ExtendedWriter + network.ExtendedReader + net.Conn dst *DstAddr id *uuid.UUID addons *Addons @@ -26,30 +37,49 @@ type Conn struct { handshake chan struct{} handshakeMutex sync.Mutex err error + + tlsConn net.Conn + input *bytes.Reader + rawInput *bytes.Buffer + + packetsToFilter int + isTLS bool + isTLS12orAbove bool + enableXTLS bool + cipher uint16 + remainingServerHello uint16 + readRemainingContent uint16 + readRemainingPadding uint16 + readFilterUUID bool + readDirect bool + writeFilterApplicationData bool + writeDirect bool } func (vc *Conn) Read(b []byte) (int, error) { if vc.received { - return vc.ExtendedConn.Read(b) + + return vc.ExtendedReader.Read(b) } if err := vc.recvResponse(); err != nil { return 0, err } vc.received = true - return vc.ExtendedConn.Read(b) + return vc.Read(b) } func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { if vc.received { - return vc.ExtendedConn.ReadBuffer(buffer) + + return vc.ExtendedReader.ReadBuffer(buffer) } if err := vc.recvResponse(); err != nil { return err } vc.received = true - return vc.ExtendedConn.ReadBuffer(buffer) + return vc.ReadBuffer(buffer) } func (vc *Conn) Write(p []byte) (int, error) { @@ -66,7 +96,19 @@ func (vc *Conn) Write(p []byte) (int, error) { return 0, vc.err } } - return vc.ExtendedConn.Write(p) + if vc.writeFilterApplicationData { + _buffer := buf.StackNew() + defer buf.KeepAlive(_buffer) + buffer := buf.Dup(_buffer) + defer buffer.Release() + buffer.Write(p) + err := vc.WriteBuffer(buffer) + if err != nil { + return 0, err + } + return len(p), nil + } + return vc.ExtendedWriter.Write(p) } func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { @@ -80,7 +122,48 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { return vc.err } } - return vc.ExtendedConn.WriteBuffer(buffer) + if vc.writeFilterApplicationData && vc.isTLS { + buffer2 := ReshapeBuffer(buffer) + defer buffer2.Release() + if buffer.Len() > 6 && bytes.Equal(buffer.To(3), tlsApplicationDataStart) { + command := commandPaddingEnd + if vc.enableXTLS { + command = commandPaddingDirect + vc.writeDirect = true + } + vc.writeFilterApplicationData = false + ApplyPadding(buffer, command, vc.id) + } + err := vc.ExtendedWriter.WriteBuffer(buffer) + if err != nil { + return err + } + if vc.writeDirect { + vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) + time.Sleep(5 * time.Millisecond) + } + if buffer2 != nil { + if vc.writeDirect { + return vc.ExtendedWriter.WriteBuffer(buffer2) + } + if buffer2.Len() > 6 && bytes.Equal(buffer2.To(3), tlsApplicationDataStart) { + command := commandPaddingEnd + if vc.enableXTLS { + command = commandPaddingDirect + vc.writeDirect = true + } + vc.writeFilterApplicationData = false + ApplyPadding(buffer2, command, vc.id) + } + err = vc.ExtendedWriter.WriteBuffer(buffer2) + } + if vc.writeDirect { + vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) + time.Sleep(5 * time.Millisecond) + } + return err + } + return vc.ExtendedWriter.WriteBuffer(buffer) } func (vc *Conn) sendRequest(p []byte) bool { @@ -96,9 +179,6 @@ func (vc *Conn) sendRequest(p []byte) bool { } defer close(vc.handshake) - requestLen := 1 // protocol version - requestLen += 16 // UUID - requestLen += 1 // addons length var addonsBytes []byte if vc.addons != nil { addonsBytes, vc.err = proto.Marshal(vc.addons) @@ -106,19 +186,32 @@ func (vc *Conn) sendRequest(p []byte) bool { return true } } - requestLen += len(addonsBytes) - requestLen += 1 // command - if !vc.dst.Mux { - requestLen += 2 // port - requestLen += 1 // addr type - requestLen += len(vc.dst.Addr) - } - requestLen += len(p) + isVision := vc.IsXTLSVisionEnabled() - _buffer := buf.StackNewSize(requestLen) - defer buf.KeepAlive(_buffer) - buffer := buf.Dup(_buffer) - defer buffer.Release() + var buffer *buf.Buffer + if isVision { + _buffer := buf.StackNew() + defer buf.KeepAlive(_buffer) + buffer = buf.Dup(_buffer) + defer buffer.Release() + } else { + requestLen := 1 // protocol version + requestLen += 16 // UUID + requestLen += 1 // addons length + requestLen += len(addonsBytes) + requestLen += 1 // command + if !vc.dst.Mux { + requestLen += 2 // port + requestLen += 1 // addr type + requestLen += len(vc.dst.Addr) + } + requestLen += len(p) + + _buffer := buf.StackNewSize(requestLen) + defer buf.KeepAlive(_buffer) + buffer = buf.Dup(_buffer) + defer buffer.Release() + } buf.Must( buffer.WriteByte(Version), // protocol version @@ -143,15 +236,52 @@ func (vc *Conn) sendRequest(p []byte) bool { ) } - buf.Must(buf.Error(buffer.Write(p))) + if isVision && !vc.dst.UDP && !vc.dst.Mux { + if len(p) == 0 { + vc.packetsToFilter = 0 + vc.writeFilterApplicationData = false + WriteWithPadding(buffer, nil, commandPaddingEnd, vc.id) + } else { + vc.FilterTLS(p) + if vc.isTLS { + WriteWithPadding(buffer, p, commandPaddingContinue, vc.id) + } else { + buf.Must(buf.Error(buffer.Write(p))) + } + } + } else { + buf.Must(buf.Error(buffer.Write(p))) + } - _, vc.err = vc.ExtendedConn.Write(buffer.Bytes()) + _, vc.err = vc.ExtendedWriter.Write(buffer.Bytes()) + if vc.err != nil { + return true + } + if isVision { + switch underlying := vc.tlsConn.(type) { + case *gotls.Conn: + if underlying.ConnectionState().Version != gotls.VersionTLS13 { + vc.err = ErrNotTLS13 + } + case *utls.UConn: + if underlying.ConnectionState().Version != utls.VersionTLS13 { + vc.err = ErrNotTLS13 + } + case *tlsC.UConn: + if underlying.ConnectionState().Version != utls.VersionTLS13 { + vc.err = ErrNotTLS13 + } + default: + vc.err = fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, vc.addons.Flow) + } + vc.tlsConn = nil + } return true } func (vc *Conn) recvResponse() error { var buf [1]byte - _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:]) + _, vc.err = io.ReadFull(vc.ExtendedReader, buf[:]) if vc.err != nil { return vc.err } @@ -160,30 +290,43 @@ func (vc *Conn) recvResponse() error { return errors.New("unexpected response version") } - _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:]) + _, vc.err = io.ReadFull(vc.ExtendedReader, buf[:]) if vc.err != nil { return vc.err } length := int64(buf[0]) if length != 0 { // addon data length > 0 - io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard + io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard } return nil } +func (vc *Conn) FrontHeadroom() int { + if vc.IsXTLSVisionEnabled() { + return paddingHeaderLen + } + return 0 +} + func (vc *Conn) Upstream() any { - return vc.ExtendedConn + return vc.Conn +} + +func (vc *Conn) IsXTLSVisionEnabled() bool { + return vc.addons != nil && vc.addons.Flow == XRV } // newConn return a Conn instance func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { c := &Conn{ - ExtendedConn: N.NewExtendedConn(conn), - id: client.uuid, - dst: dst, - handshake: make(chan struct{}), + ExtendedReader: N.NewExtendedReader(conn), + ExtendedWriter: N.NewExtendedWriter(conn), + Conn: conn, + id: client.uuid, + dst: dst, + handshake: make(chan struct{}), } if !dst.UDP && client.Addons != nil { @@ -204,15 +347,46 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { } else { return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", client.Addons.Flow) } + case XRV: + c.packetsToFilter = 2048 + c.writeFilterApplicationData = true + var t reflect.Type + var p uintptr + switch underlying := conn.(type) { + case *gotls.Conn: + c.Conn = underlying.NetConn() + c.tlsConn = conn + t = reflect.TypeOf(underlying).Elem() + p = uintptr(unsafe.Pointer(underlying)) + case *utls.UConn: + c.Conn = underlying.NetConn() + c.tlsConn = conn + t = reflect.TypeOf(underlying.Conn).Elem() + p = uintptr(unsafe.Pointer(underlying.Conn)) + case *tlsC.UConn: + c.Conn = underlying.NetConn() + c.tlsConn = underlying.UConn + t = reflect.TypeOf(underlying.Conn).Elem() + p = uintptr(unsafe.Pointer(underlying.Conn)) + default: + return nil, fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, client.Addons.Flow) + } + i, _ := t.FieldByName("input") + r, _ := t.FieldByName("rawInput") + c.input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) + c.rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) + if _, ok := c.Conn.(*net.TCPConn); !ok { + log.Debugln("XTLS underlying conn is not *net.TCPConn, got", reflect.TypeOf(conn).Name()) + } } } - //go func() { - // select { - // case <-c.handshake: - // case <-time.After(200 * time.Millisecond): - // c.sendRequest(nil) - // } - //}() + go func() { + select { + case <-c.handshake: + case <-time.After(200 * time.Millisecond): + c.sendRequest(nil) + } + }() return c, nil } diff --git a/transport/vless/filter.go b/transport/vless/filter.go new file mode 100644 index 00000000..2659b031 --- /dev/null +++ b/transport/vless/filter.go @@ -0,0 +1,76 @@ +package vless + +import ( + "bytes" + "encoding/binary" + + log "github.com/sirupsen/logrus" +) + +var ( + tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} + tlsClientHandshakeStart = []byte{0x16, 0x03} + tlsServerHandshakeStart = []byte{0x16, 0x03, 0x03} + tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} + + tls13CipherSuiteMap = map[uint16]string{ + 0x1301: "TLS_AES_128_GCM_SHA256", + 0x1302: "TLS_AES_256_GCM_SHA384", + 0x1303: "TLS_CHACHA20_POLY1305_SHA256", + 0x1304: "TLS_AES_128_CCM_SHA256", + 0x1305: "TLS_AES_128_CCM_8_SHA256", + } +) + +const ( + tlsHandshakeTypeClientHello byte = 0x01 + tlsHandshakeTypeServerHello byte = 0x02 +) + +func (vc *Conn) FilterTLS(p []byte) (index int) { + lenP := len(p) + vc.packetsToFilter -= 1 + if index = bytes.Index(p, tlsServerHandshakeStart); index != -1 { + if lenP >= index+5 && p[index+5] == tlsHandshakeTypeServerHello { + vc.remainingServerHello = binary.BigEndian.Uint16(p[index+3:]) + 5 + vc.isTLS = true + vc.isTLS12orAbove = true + if lenP-index >= 79 && vc.remainingServerHello >= 79 { + sessionIDLen := int(p[index+43]) + vc.cipher = binary.BigEndian.Uint16(p[index+43+sessionIDLen+1:]) + } + } + } else if index = bytes.Index(p, tlsClientHandshakeStart); index != -1 { + if lenP >= index+5 && p[index+5] == tlsHandshakeTypeClientHello { + vc.isTLS = true + } + } + + if vc.remainingServerHello > 0 { + end := vc.remainingServerHello + vc.remainingServerHello -= end + if end > uint16(lenP) { + end = uint16(lenP) + } + if bytes.Contains(p[index:end], tls13SupportedVersions) { + // TLS 1.3 Client Hello + cs, ok := tls13CipherSuiteMap[vc.cipher] + if ok && cs != "TLS_AES_128_CCM_8_SHA256" { + vc.enableXTLS = true + } + log.Debugln("XTLS Vision found TLS 1.3, packetLength=", lenP, ", CipherSuite=", cs) + vc.packetsToFilter = 0 + return + } else if vc.remainingServerHello < 0 { + log.Debugln("XTLS Vision found TLS 1.2, packetLength=", lenP) + vc.packetsToFilter = 0 + return + } + log.Debugln("XTLS Vision found inconclusive server hello, packetLength=", lenP, + ", remainingServerHelloBytes=", vc.remainingServerHello) + } + if vc.packetsToFilter <= 0 { + log.Debugln("XTLS Vision stop filtering") + } + return +} diff --git a/transport/vless/vision.go b/transport/vless/vision.go new file mode 100644 index 00000000..20f73613 --- /dev/null +++ b/transport/vless/vision.go @@ -0,0 +1,67 @@ +package vless + +import ( + "bytes" + "encoding/binary" + "math/rand" + + "github.com/Dreamacro/clash/common/buf" + + "github.com/gofrs/uuid" + buf2 "github.com/sagernet/sing/common/buf" +) + +const ( + paddingHeaderLen = 16 + 1 + 2 + 2 // =21 + + commandPaddingContinue byte = 0x00 + commandPaddingEnd byte = 0x01 + commandPaddingDirect byte = 0x02 +) + +func WriteWithPadding(buffer *buf.Buffer, p []byte, command byte, userUUID *uuid.UUID) { + contentLen := int32(len(p)) + var paddingLen int32 + if contentLen < 900 { + paddingLen = rand.Int31n(500) + 900 - contentLen + } + + if userUUID != nil { // unnecessary, but keep the same with Xray + buffer.Write(userUUID.Bytes()) + } + buffer.WriteByte(command) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(contentLen)) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(paddingLen)) + buffer.Write(p) + buffer.Extend(int(paddingLen)) +} + +func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID) { + contentLen := int32(buffer.Len()) + var paddingLen int32 + if contentLen < 900 { + paddingLen = rand.Int31n(500) + 900 - contentLen + } + + binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(paddingLen)) + binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(contentLen)) + buffer.ExtendHeader(1)[0] = command + if userUUID != nil { // unnecessary, but keep the same with Xray + copy(buffer.ExtendHeader(uuid.Size), userUUID.Bytes()) + } + buffer.Extend(int(paddingLen)) +} + +func ReshapeBuffer(buffer *buf.Buffer) *buf.Buffer { + if buffer.Len() <= buf2.BufferSize-paddingHeaderLen { + return nil + } + cutAt := bytes.LastIndex(buffer.Bytes(), tlsApplicationDataStart) + if cutAt == -1 { + cutAt = buf2.BufferSize / 2 + } + buffer2 := buf2.New() + buffer2.Write(buffer.From(cutAt)) + buffer.Truncate(cutAt) + return buffer2 +} diff --git a/transport/vless/vless.go b/transport/vless/vless.go index 4b101703..6989374c 100644 --- a/transport/vless/vless.go +++ b/transport/vless/vless.go @@ -12,6 +12,7 @@ const ( XRO = "xtls-rprx-origin" XRD = "xtls-rprx-direct" XRS = "xtls-rprx-splice" + XRV = "xtls-rprx-vision" Version byte = 0 // protocol version. preview version is 0 ) diff --git a/transport/vless/xtls.go b/transport/vless/xtls.go index a1aea44f..3a319568 100644 --- a/transport/vless/xtls.go +++ b/transport/vless/xtls.go @@ -2,6 +2,7 @@ package vless import ( "context" + "errors" "net" tlsC "github.com/Dreamacro/clash/component/tls" @@ -9,6 +10,10 @@ import ( xtls "github.com/xtls/go" ) +var ( + ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection") +) + type XTLSConfig struct { Host string SkipCertVerify bool diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 97dd7316..133a8b8a 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -28,7 +28,7 @@ type trackerInfo struct { RulePayload string `json:"rulePayload"` } -type tcpTracker struct { +type TCPTracker struct { C.Conn `json:"-"` *trackerInfo manager *Manager @@ -36,11 +36,16 @@ type tcpTracker struct { extendedWriter N.ExtendedWriter } -func (tt *tcpTracker) ID() string { +func (tt *TCPTracker) ID() string { return tt.UUID.String() } -func (tt *tcpTracker) Read(b []byte) (int, error) { +func (tt *TCPTracker) AddDownload(n int64) { + tt.manager.PushDownloaded(n) + tt.DownloadTotal.Add(n) +} + +func (tt *TCPTracker) Read(b []byte) (int, error) { n, err := tt.Conn.Read(b) download := int64(n) tt.manager.PushDownloaded(download) @@ -48,7 +53,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) { return n, err } -func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) { +func (tt *TCPTracker) ReadBuffer(buffer *buf.Buffer) (err error) { err = tt.extendedReader.ReadBuffer(buffer) download := int64(buffer.Len()) tt.manager.PushDownloaded(download) @@ -56,7 +61,12 @@ func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) { return } -func (tt *tcpTracker) Write(b []byte) (int, error) { +func (tt *TCPTracker) AddUpload(n int64) { + tt.manager.PushUploaded(n) + tt.UploadTotal.Add(n) +} + +func (tt *TCPTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) @@ -64,7 +74,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { return n, err } -func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) { +func (tt *TCPTracker) WriteBuffer(buffer *buf.Buffer) (err error) { upload := int64(buffer.Len()) err = tt.extendedWriter.WriteBuffer(buffer) tt.manager.PushUploaded(upload) @@ -72,16 +82,16 @@ func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) { return } -func (tt *tcpTracker) Close() error { +func (tt *TCPTracker) Close() error { tt.manager.Leave(tt) return tt.Conn.Close() } -func (tt *tcpTracker) Upstream() any { +func (tt *TCPTracker) Upstream() any { return tt.Conn } -func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { +func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *TCPTracker { uuid, _ := uuid.NewV4() if conn != nil { if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { @@ -91,7 +101,7 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R } } - t := &tcpTracker{ + t := &TCPTracker{ Conn: conn, manager: manager, trackerInfo: &trackerInfo{