This commit is contained in:
H1JK 2023-02-24 10:03:01 +08:00
parent fc58f80cc8
commit 11e0bbebf4
3 changed files with 49 additions and 31 deletions

View File

@ -53,7 +53,8 @@ type Conn struct {
remainingServerHello uint16 remainingServerHello uint16
readRemainingContent int readRemainingContent int
readRemainingPadding int readRemainingPadding int
readFilter bool readProcess bool
readFilterUUID bool
readLastCommand byte readLastCommand byte
writeFilterApplicationData bool writeFilterApplicationData bool
writeDirect bool writeDirect bool
@ -61,7 +62,7 @@ type Conn struct {
func (vc *Conn) Read(b []byte) (int, error) { func (vc *Conn) Read(b []byte) (int, error) {
if vc.received { if vc.received {
if vc.readFilter { if vc.readProcess {
buffer := buf2.As(b) buffer := buf2.As(b)
err := vc.ReadBuffer(buffer) err := vc.ReadBuffer(buffer)
return buffer.Len(), err return buffer.Len(), err
@ -86,7 +87,7 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
n, err := vc.ExtendedReader.Read(toRead) n, err := vc.ExtendedReader.Read(toRead)
buffer.Truncate(n) buffer.Truncate(n)
vc.readRemainingContent -= n vc.readRemainingContent -= n
vc.FilterTLS(buffer.Bytes()) vc.FilterTLS(toRead)
return err return err
} }
if vc.readRemainingPadding > 0 { if vc.readRemainingPadding > 0 {
@ -96,31 +97,41 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
} }
vc.readRemainingPadding = 0 vc.readRemainingPadding = 0
} }
if vc.readFilter { if vc.readProcess {
switch vc.readLastCommand { switch vc.readLastCommand {
case commandPaddingContinue: case commandPaddingContinue:
//if vc.isTLS || vc.packetsToFilter > 0 { //if vc.isTLS || vc.packetsToFilter > 0 {
header := buffer.FreeBytes()[:paddingHeaderLen] headerUUIDLen := 0
if vc.readFilterUUID {
headerUUIDLen = uuid.Size
}
header := buffer.FreeBytes()[:paddingHeaderLen+headerUUIDLen]
_, err := io.ReadFull(vc.ExtendedReader, header) _, err := io.ReadFull(vc.ExtendedReader, header)
if err != nil { if err != nil {
return err return err
} }
pos := 0
if vc.readFilterUUID {
vc.readFilterUUID = false
pos = uuid.Size
if subtle.ConstantTimeCompare(vc.id.Bytes(), header[:uuid.Size]) != 1 { if subtle.ConstantTimeCompare(vc.id.Bytes(), header[:uuid.Size]) != 1 {
return fmt.Errorf("XTLS Vision server responded unknown UUID: %s", err = fmt.Errorf("XTLS Vision server responded unknown UUID: %s",
uuid.FromBytesOrNil(header[:uuid.Size]).String()) uuid.FromBytesOrNil(header[:uuid.Size]).String())
log.Errorln(err.Error())
return err
} }
vc.readLastCommand = header[uuid.Size] }
vc.readRemainingContent = int(binary.BigEndian.Uint16(header[uuid.Size+1:])) vc.readLastCommand = header[pos]
vc.readRemainingPadding = int(binary.BigEndian.Uint16(header[uuid.Size+3:])) vc.readRemainingContent = int(binary.BigEndian.Uint16(header[pos+1:]))
vc.readRemainingPadding = int(binary.BigEndian.Uint16(header[pos+3:]))
log.Debugln("XTLS Vision read padding: command=%d, payloadLen=%d, paddingLen=%d", log.Debugln("XTLS Vision read padding: command=%d, payloadLen=%d, paddingLen=%d",
vc.readLastCommand, vc.readRemainingContent, vc.readRemainingPadding) vc.readLastCommand, vc.readRemainingContent, vc.readRemainingPadding)
return vc.ReadBuffer(buffer) return vc.ReadBuffer(buffer)
//} //}
case commandPaddingEnd: case commandPaddingEnd:
vc.readFilter = false vc.readProcess = false
return vc.ReadBuffer(buffer) return vc.ReadBuffer(buffer)
case commandPaddingDirect: case commandPaddingDirect:
log.Debugln("command read direct")
if vc.input != nil { if vc.input != nil {
_, err := buffer.ReadFrom(vc.input) _, err := buffer.ReadFrom(vc.input)
if err != nil { if err != nil {
@ -143,12 +154,14 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
} }
} }
if vc.input == nil && vc.rawInput == nil { if vc.input == nil && vc.rawInput == nil {
vc.readFilter = false vc.readProcess = false
vc.ExtendedReader = N.NewExtendedReader(vc.Conn) vc.ExtendedReader = N.NewExtendedReader(vc.Conn)
log.Debugln("XTLS Vision Direct read start") log.Debugln("XTLS Vision direct read start")
} }
default: default:
log.Debugln("XTLS Vision read unknown command: %d", vc.readLastCommand) err := fmt.Errorf("XTLS Vision read unknown command: %d", vc.readLastCommand)
log.Debugln(err.Error())
return err
} }
} }
return vc.ExtendedReader.ReadBuffer(buffer) return vc.ExtendedReader.ReadBuffer(buffer)
@ -214,14 +227,15 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
} }
vc.writeFilterApplicationData = false vc.writeFilterApplicationData = false
} }
ApplyPadding(buffer, command, vc.id) ApplyPadding(buffer, command, nil)
err := vc.ExtendedWriter.WriteBuffer(buffer) err := vc.ExtendedWriter.WriteBuffer(buffer)
if err != nil { if err != nil {
return err return err
} }
if vc.writeDirect { if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
time.Sleep(20 * time.Millisecond) log.Debugln("XTLS Vision direct write start")
//time.Sleep(10 * time.Millisecond)
} }
if buffer2 != nil { if buffer2 != nil {
if vc.writeDirect { if vc.writeDirect {
@ -237,18 +251,19 @@ func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
} }
vc.writeFilterApplicationData = false vc.writeFilterApplicationData = false
} }
ApplyPadding(buffer2, command, vc.id) ApplyPadding(buffer2, command, nil)
err = vc.ExtendedWriter.WriteBuffer(buffer2) err = vc.ExtendedWriter.WriteBuffer(buffer2)
if vc.writeDirect { if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn) vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
time.Sleep(20 * time.Millisecond) log.Debugln("XTLS Vision direct write start")
//time.Sleep(10 * time.Millisecond)
} }
} }
return err return err
} }
if vc.writeDirect { /*if vc.writeDirect {
log.Debugln("XTLS Vision Direct write, payloadLen=%d", buffer.Len()) log.Debugln("XTLS Vision Direct write, payloadLen=%d", buffer.Len())
} }*/
return vc.ExtendedWriter.WriteBuffer(buffer) return vc.ExtendedWriter.WriteBuffer(buffer)
} }
@ -333,6 +348,9 @@ func (vc *Conn) sendRequest(p []byte) bool {
WriteWithPadding(buffer, p, commandPaddingContinue, vc.id) WriteWithPadding(buffer, p, commandPaddingContinue, vc.id)
} else { } else {
buf.Must(buf.Error(buffer.Write(p))) buf.Must(buf.Error(buffer.Write(p)))
vc.readProcess = false
vc.writeFilterApplicationData = false
vc.packetsToFilter = 0
} }
} }
} else { } else {
@ -353,10 +371,6 @@ func (vc *Conn) sendRequest(p []byte) bool {
if underlying.ConnectionState().Version != utls.VersionTLS13 { if underlying.ConnectionState().Version != utls.VersionTLS13 {
vc.err = ErrNotTLS13 vc.err = ErrNotTLS13
} }
case *tlsC.UConn:
if underlying.ConnectionState().Version != utls.VersionTLS13 {
vc.err = ErrNotTLS13
}
default: default:
vc.err = fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, vc.addons.Flow) vc.err = fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, vc.addons.Flow)
} }
@ -437,8 +451,9 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", client.Addons.Flow) return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", client.Addons.Flow)
} }
case XRV: case XRV:
c.packetsToFilter = 10 c.packetsToFilter = 6
c.readFilter = true c.readProcess = true
c.readFilterUUID = true
c.writeFilterApplicationData = true c.writeFilterApplicationData = true
c.addons = client.Addons c.addons = client.Addons
var t reflect.Type var t reflect.Type
@ -446,12 +461,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
switch underlying := conn.(type) { switch underlying := conn.(type) {
case *gotls.Conn: case *gotls.Conn:
c.Conn = underlying.NetConn() c.Conn = underlying.NetConn()
c.tlsConn = conn c.tlsConn = underlying
t = reflect.TypeOf(underlying).Elem() t = reflect.TypeOf(underlying).Elem()
p = uintptr(unsafe.Pointer(underlying)) p = uintptr(unsafe.Pointer(underlying))
case *utls.UConn: case *utls.UConn:
c.Conn = underlying.NetConn() c.Conn = underlying.NetConn()
c.tlsConn = conn c.tlsConn = underlying
t = reflect.TypeOf(underlying.Conn).Elem() t = reflect.TypeOf(underlying.Conn).Elem()
p = uintptr(unsafe.Pointer(underlying.Conn)) p = uintptr(unsafe.Pointer(underlying.Conn))
case *tlsC.UConn: case *tlsC.UConn:

View File

@ -28,6 +28,9 @@ const (
) )
func (vc *Conn) FilterTLS(p []byte) (index int) { func (vc *Conn) FilterTLS(p []byte) (index int) {
if vc.packetsToFilter <= 0 {
return 0
}
lenP := len(p) lenP := len(p)
vc.packetsToFilter -= 1 vc.packetsToFilter -= 1
if index = bytes.Index(p, tlsServerHandshakeStart); index != -1 { if index = bytes.Index(p, tlsServerHandshakeStart); index != -1 {

View File

@ -13,7 +13,7 @@ import (
) )
const ( const (
paddingHeaderLen = uuid.Size + 1 + 2 + 2 // =21 paddingHeaderLen = 1 + 2 + 2 // =5
commandPaddingContinue byte = 0x00 commandPaddingContinue byte = 0x00
commandPaddingEnd byte = 0x01 commandPaddingEnd byte = 0x01