refactor: Move vision implementation to a new package

This commit is contained in:
H1JK
2023-05-26 20:11:06 +08:00
parent 984bf27d9b
commit 654e76d91e
6 changed files with 362 additions and 300 deletions

View File

@ -0,0 +1,274 @@
package vision
import (
"bytes"
"crypto/subtle"
gotls "crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/log"
"github.com/gofrs/uuid/v5"
utls "github.com/sagernet/utls"
)
var (
_ N.ExtendedConn = (*Conn)(nil)
)
type Conn struct {
net.Conn
N.ExtendedReader
N.ExtendedWriter
upstream net.Conn
userUUID *uuid.UUID
tlsConn net.Conn
input *bytes.Reader
rawInput *bytes.Buffer
needHandshake bool
packetsToFilter int
isTLS bool
isTLS12orAbove bool
enableXTLS bool
cipher uint16
remainingServerHello uint16
readRemainingContent int
readRemainingPadding int
readProcess bool
readFilterUUID bool
readLastCommand byte
writeFilterApplicationData bool
writeDirect bool
}
func (vc *Conn) Read(b []byte) (int, error) {
if vc.readProcess {
buffer := buf.With(b)
err := vc.ReadBuffer(buffer)
return buffer.Len(), err
}
return vc.ExtendedReader.Read(b)
}
func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
toRead := buffer.FreeBytes()
if vc.readRemainingContent > 0 {
if vc.readRemainingContent < buffer.FreeLen() {
toRead = toRead[:vc.readRemainingContent]
}
n, err := vc.ExtendedReader.Read(toRead)
buffer.Truncate(n)
vc.readRemainingContent -= n
vc.FilterTLS(toRead)
return err
}
if vc.readRemainingPadding > 0 {
_, err := io.CopyN(io.Discard, vc.ExtendedReader, int64(vc.readRemainingPadding))
if err != nil {
return err
}
vc.readRemainingPadding = 0
}
if vc.readProcess {
switch vc.readLastCommand {
case commandPaddingContinue:
//if vc.isTLS || vc.packetsToFilter > 0 {
headerUUIDLen := 0
if vc.readFilterUUID {
headerUUIDLen = uuid.Size
}
var header []byte
if need := headerUUIDLen + PaddingHeaderLen - uuid.Size; buffer.FreeLen() < need {
header = make([]byte, need)
} else {
header = buffer.FreeBytes()[:need]
}
_, err := io.ReadFull(vc.ExtendedReader, header)
if err != nil {
return err
}
if vc.readFilterUUID {
vc.readFilterUUID = false
if subtle.ConstantTimeCompare(vc.userUUID.Bytes(), header[:uuid.Size]) != 1 {
err = fmt.Errorf("XTLS Vision server responded unknown UUID: %s",
uuid.FromBytesOrNil(header[:uuid.Size]).String())
log.Errorln(err.Error())
return err
}
header = header[uuid.Size:]
}
vc.readRemainingPadding = int(binary.BigEndian.Uint16(header[3:]))
vc.readRemainingContent = int(binary.BigEndian.Uint16(header[1:]))
vc.readLastCommand = header[0]
log.Debugln("XTLS Vision read padding: command=%d, payloadLen=%d, paddingLen=%d",
vc.readLastCommand, vc.readRemainingContent, vc.readRemainingPadding)
return vc.ReadBuffer(buffer)
//}
case commandPaddingEnd:
vc.readProcess = false
return vc.ReadBuffer(buffer)
case commandPaddingDirect:
needReturn := false
if vc.input != nil {
_, err := buffer.ReadFrom(vc.input)
if err != nil {
return err
}
if vc.input.Len() == 0 {
needReturn = true
vc.input = nil
} else { // buffer is full
return nil
}
}
if vc.rawInput != nil {
_, err := buffer.ReadFrom(vc.rawInput)
if err != nil {
return err
}
needReturn = true
if vc.rawInput.Len() == 0 {
vc.rawInput = nil
}
}
if vc.input == nil && vc.rawInput == nil {
vc.readProcess = false
vc.ExtendedReader = N.NewExtendedReader(vc.Conn)
log.Debugln("XTLS Vision direct read start")
}
if needReturn {
return nil
}
default:
err := fmt.Errorf("XTLS Vision read unknown command: %d", vc.readLastCommand)
log.Debugln(err.Error())
return err
}
}
return vc.ExtendedReader.ReadBuffer(buffer)
}
func (vc *Conn) Write(p []byte) (int, error) {
if vc.writeFilterApplicationData {
_buffer := buf.StackNew()
defer buf.KeepAlive(_buffer)
buffer := buf.Dup(_buffer)
defer buffer.Release()
buffer.Write(p)
err := vc.WriteBuffer(buffer)
if err != nil {
return 0, err
}
return len(p), nil
}
return vc.ExtendedWriter.Write(p)
}
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) (err error) {
if vc.needHandshake {
vc.needHandshake = false
if buffer.IsEmpty() {
ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, false)
} else {
vc.FilterTLS(buffer.Bytes())
ApplyPadding(buffer, commandPaddingContinue, vc.userUUID, vc.isTLS)
}
err = vc.ExtendedWriter.WriteBuffer(buffer)
if err != nil {
buffer.Release()
return err
}
switch underlying := vc.tlsConn.(type) {
case *gotls.Conn:
if underlying.ConnectionState().Version != gotls.VersionTLS13 {
buffer.Release()
return ErrNotTLS13
}
case *utls.UConn:
if underlying.ConnectionState().Version != utls.VersionTLS13 {
buffer.Release()
return ErrNotTLS13
}
}
vc.tlsConn = nil
return nil
}
if vc.writeFilterApplicationData {
buffer2 := ReshapeBuffer(buffer)
defer buffer2.Release()
vc.FilterTLS(buffer.Bytes())
command := commandPaddingContinue
if !vc.isTLS {
command = commandPaddingEnd
// disable XTLS
//vc.readProcess = false
vc.writeFilterApplicationData = false
vc.packetsToFilter = 0
} else if buffer.Len() > 6 && bytes.Equal(buffer.To(3), tlsApplicationDataStart) || vc.packetsToFilter <= 0 {
command = commandPaddingEnd
if vc.enableXTLS {
command = commandPaddingDirect
vc.writeDirect = true
}
vc.writeFilterApplicationData = false
}
ApplyPadding(buffer, command, nil, vc.isTLS)
err = vc.ExtendedWriter.WriteBuffer(buffer)
if err != nil {
return err
}
if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
log.Debugln("XTLS Vision direct write start")
//time.Sleep(5 * time.Millisecond)
}
if buffer2 != nil {
if vc.writeDirect || !vc.isTLS {
return vc.ExtendedWriter.WriteBuffer(buffer2)
}
vc.FilterTLS(buffer2.Bytes())
command = commandPaddingContinue
if buffer2.Len() > 6 && bytes.Equal(buffer2.To(3), tlsApplicationDataStart) || vc.packetsToFilter <= 0 {
command = commandPaddingEnd
if vc.enableXTLS {
command = commandPaddingDirect
vc.writeDirect = true
}
vc.writeFilterApplicationData = false
}
ApplyPadding(buffer2, command, nil, vc.isTLS)
err = vc.ExtendedWriter.WriteBuffer(buffer2)
if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
log.Debugln("XTLS Vision direct write start")
//time.Sleep(10 * time.Millisecond)
}
}
return err
}
/*if vc.writeDirect {
log.Debugln("XTLS Vision Direct write, payloadLen=%d", buffer.Len())
}*/
return vc.ExtendedWriter.WriteBuffer(buffer)
}
func (vc *Conn) FrontHeadroom() int {
return PaddingHeaderLen
}
func (vc *Conn) NeedHandshake() bool {
return vc.needHandshake
}
func (vc *Conn) Upstream() any {
return vc.upstream
}

View File

@ -0,0 +1,90 @@
package vision
import (
"bytes"
"encoding/binary"
"github.com/Dreamacro/clash/log"
)
var (
tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}
tlsClientHandshakeStart = []byte{0x16, 0x03}
tlsServerHandshakeStart = []byte{0x16, 0x03, 0x03}
tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
tls13CipherSuiteMap = map[uint16]string{
0x1301: "TLS_AES_128_GCM_SHA256",
0x1302: "TLS_AES_256_GCM_SHA384",
0x1303: "TLS_CHACHA20_POLY1305_SHA256",
0x1304: "TLS_AES_128_CCM_SHA256",
0x1305: "TLS_AES_128_CCM_8_SHA256",
}
)
const (
tlsHandshakeTypeClientHello byte = 0x01
tlsHandshakeTypeServerHello byte = 0x02
)
func (vc *Conn) FilterTLS(buffer []byte) (index int) {
if vc.packetsToFilter <= 0 {
return 0
}
lenP := len(buffer)
vc.packetsToFilter--
if index = bytes.Index(buffer, tlsServerHandshakeStart); index != -1 {
if lenP > index+5 {
if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
vc.isTLS = true
if buffer[5] == tlsHandshakeTypeServerHello {
//log.Debugln("isTLS12orAbove")
vc.remainingServerHello = binary.BigEndian.Uint16(buffer[index+3:]) + 5
vc.isTLS12orAbove = true
if lenP-index >= 79 && vc.remainingServerHello >= 79 {
sessionIDLen := int(buffer[index+43])
vc.cipher = binary.BigEndian.Uint16(buffer[index+43+sessionIDLen+1:])
}
}
}
}
} else if index = bytes.Index(buffer, tlsClientHandshakeStart); index != -1 {
if lenP > index+5 && buffer[index+5] == tlsHandshakeTypeClientHello {
vc.isTLS = true
}
}
if vc.remainingServerHello > 0 {
end := int(vc.remainingServerHello)
i := index
if i < 0 {
i = 0
}
if i+end > lenP {
end = lenP
vc.remainingServerHello -= uint16(end - i)
} else {
vc.remainingServerHello -= uint16(end)
end += i
}
if bytes.Contains(buffer[i:end], tls13SupportedVersions) {
// TLS 1.3 Client Hello
cs, ok := tls13CipherSuiteMap[vc.cipher]
if ok && cs != "TLS_AES_128_CCM_8_SHA256" {
vc.enableXTLS = true
}
log.Debugln("XTLS Vision found TLS 1.3, packetLength=%d CipherSuite=%s", lenP, cs)
vc.packetsToFilter = 0
return
} else if vc.remainingServerHello <= 0 {
log.Debugln("XTLS Vision found TLS 1.2, packetLength=%d", lenP)
vc.packetsToFilter = 0
return
}
log.Debugln("XTLS Vision found inconclusive server hello, packetLength=%d, remainingServerHelloBytes=%d", lenP, vc.remainingServerHello)
}
if vc.packetsToFilter <= 0 {
log.Debugln("XTLS Vision stop filtering")
}
return
}

View File

@ -0,0 +1,81 @@
package vision
import (
"bytes"
"encoding/binary"
"github.com/Dreamacro/clash/common/buf"
"github.com/Dreamacro/clash/log"
"github.com/gofrs/uuid/v5"
"github.com/zhangyunhao116/fastrand"
)
const (
PaddingHeaderLen = uuid.Size + 1 + 2 + 2 // =21
commandPaddingContinue byte = 0x00
commandPaddingEnd byte = 0x01
commandPaddingDirect byte = 0x02
)
func WriteWithPadding(buffer *buf.Buffer, p []byte, command byte, userUUID *uuid.UUID, paddingTLS bool) {
contentLen := int32(len(p))
var paddingLen int32
if contentLen < 900 {
if paddingTLS {
//log.Debugln("long padding")
paddingLen = fastrand.Int31n(500) + 900 - contentLen
} else {
paddingLen = fastrand.Int31n(256)
}
}
if userUUID != nil {
buffer.Write(userUUID.Bytes())
}
buffer.WriteByte(command)
binary.BigEndian.PutUint16(buffer.Extend(2), uint16(contentLen))
binary.BigEndian.PutUint16(buffer.Extend(2), uint16(paddingLen))
buffer.Write(p)
buffer.Extend(int(paddingLen))
log.Debugln("XTLS Vision write padding1: command=%v, payloadLen=%v, paddingLen=%v", command, contentLen, paddingLen)
}
func ApplyPadding(buffer *buf.Buffer, command byte, userUUID *uuid.UUID, paddingTLS bool) {
contentLen := int32(buffer.Len())
var paddingLen int32
if contentLen < 900 {
if paddingTLS {
//log.Debugln("long padding")
paddingLen = fastrand.Int31n(500) + 900 - contentLen
} else {
paddingLen = fastrand.Int31n(256)
}
}
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(paddingLen))
binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(contentLen))
buffer.ExtendHeader(1)[0] = command
if userUUID != nil {
copy(buffer.ExtendHeader(uuid.Size), userUUID.Bytes())
}
buffer.Extend(int(paddingLen))
log.Debugln("XTLS Vision write padding2: command=%d, payloadLen=%d, paddingLen=%d", command, contentLen, paddingLen)
}
func ReshapeBuffer(buffer *buf.Buffer) *buf.Buffer {
if buffer.Len() <= buf.BufferSize-PaddingHeaderLen {
return nil
}
cutAt := bytes.LastIndex(buffer.Bytes(), tlsApplicationDataStart)
if cutAt == -1 {
cutAt = buf.BufferSize / 2
}
buffer2 := buf.New()
buffer2.Write(buffer.From(cutAt))
buffer.Truncate(cutAt)
return buffer2
}

View File

@ -0,0 +1,70 @@
// Package vision implements VLESS flow `xtls-rprx-vision` introduced by Xray-core.
package vision
import (
"bytes"
gotls "crypto/tls"
"errors"
"fmt"
"net"
"reflect"
"unsafe"
N "github.com/Dreamacro/clash/common/net"
tlsC "github.com/Dreamacro/clash/component/tls"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing/common"
utls "github.com/sagernet/utls"
)
var ErrNotTLS13 = errors.New("XTLS Vision based on TLS 1.3 outer connection")
type connWithUpstream interface {
net.Conn
common.WithUpstream
}
func NewConn(conn connWithUpstream, userUUID *uuid.UUID) (*Conn, error) {
c := &Conn{
ExtendedReader: N.NewExtendedReader(conn),
ExtendedWriter: N.NewExtendedWriter(conn),
upstream: conn,
userUUID: userUUID,
packetsToFilter: 6,
needHandshake: true,
readProcess: true,
readFilterUUID: true,
writeFilterApplicationData: true,
}
var t reflect.Type
var p unsafe.Pointer
switch underlying := conn.Upstream().(type) {
case *gotls.Conn:
//log.Debugln("type tls")
c.Conn = underlying.NetConn()
c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem()
p = unsafe.Pointer(underlying)
case *utls.UConn:
//log.Debugln("type *utls.UConn")
c.Conn = underlying.NetConn()
c.tlsConn = underlying
t = reflect.TypeOf(underlying.Conn).Elem()
p = unsafe.Pointer(underlying.Conn)
case *tlsC.UConn:
//log.Debugln("type *tlsC.UConn")
c.Conn = underlying.NetConn()
c.tlsConn = underlying.UConn
t = reflect.TypeOf(underlying.Conn).Elem()
//log.Debugln("t:%v", t)
p = unsafe.Pointer(underlying.Conn)
default:
return nil, fmt.Errorf(`failed to use vision, maybe "security" is not "tls" or "utls"`)
}
i, _ := t.FieldByName("input")
r, _ := t.FieldByName("rawInput")
c.input = (*bytes.Reader)(unsafe.Add(p, i.Offset))
c.rawInput = (*bytes.Buffer)(unsafe.Add(p, r.Offset))
return c, nil
}