This repository has been archived on 2024-09-06. You can view files and clone it, but cannot push or open issues or pull requests.
2023-02-24 20:46:59 +08:00

393 lines
8.9 KiB
Go

package vless
import (
"bytes"
gotls "crypto/tls"
"encoding/binary"
"errors"
"fmt"
tlsC "github.com/Dreamacro/clash/component/tls"
"github.com/Dreamacro/clash/log"
utls "github.com/refraction-networking/utls"
"github.com/sagernet/sing/common/network"
"io"
"net"
"reflect"
"sync"
"time"
"unsafe"
"github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net"
"github.com/gofrs/uuid"
xtls "github.com/xtls/go"
"google.golang.org/protobuf/proto"
)
type Conn struct {
network.ExtendedWriter
network.ExtendedReader
net.Conn
dst *DstAddr
id *uuid.UUID
addons *Addons
received bool
handshake chan struct{}
handshakeMutex sync.Mutex
err error
tlsConn net.Conn
input *bytes.Reader
rawInput *bytes.Buffer
packetsToFilter int
isTLS bool
isTLS12orAbove bool
enableXTLS bool
cipher uint16
remainingServerHello uint16
readRemainingContent uint16
readRemainingPadding uint16
readFilterUUID bool
readDirect bool
writeFilterApplicationData bool
writeDirect bool
}
func (vc *Conn) Read(b []byte) (int, error) {
if vc.received {
return vc.ExtendedReader.Read(b)
}
if err := vc.recvResponse(); err != nil {
return 0, err
}
vc.received = true
return vc.Read(b)
}
func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
if vc.received {
return vc.ExtendedReader.ReadBuffer(buffer)
}
if err := vc.recvResponse(); err != nil {
return err
}
vc.received = true
return vc.ReadBuffer(buffer)
}
func (vc *Conn) Write(p []byte) (int, error) {
select {
case <-vc.handshake:
default:
if vc.sendRequest(p) {
if vc.err != nil {
return 0, vc.err
}
return len(p), vc.err
}
if vc.err != nil {
return 0, vc.err
}
}
if vc.writeFilterApplicationData {
_buffer := buf.StackNew()
defer buf.KeepAlive(_buffer)
buffer := buf.Dup(_buffer)
defer buffer.Release()
buffer.Write(p)
err := vc.WriteBuffer(buffer)
if err != nil {
return 0, err
}
return len(p), nil
}
return vc.ExtendedWriter.Write(p)
}
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
select {
case <-vc.handshake:
default:
if vc.sendRequest(buffer.Bytes()) {
return vc.err
}
if vc.err != nil {
return vc.err
}
}
if vc.writeFilterApplicationData && vc.isTLS {
buffer2 := ReshapeBuffer(buffer)
defer buffer2.Release()
if buffer.Len() > 6 && bytes.Equal(buffer.To(3), tlsApplicationDataStart) {
command := commandPaddingEnd
if vc.enableXTLS {
command = commandPaddingDirect
vc.writeDirect = true
}
vc.writeFilterApplicationData = false
ApplyPadding(buffer, command, vc.id)
}
err := vc.ExtendedWriter.WriteBuffer(buffer)
if err != nil {
return err
}
if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
time.Sleep(5 * time.Millisecond)
}
if buffer2 != nil {
if vc.writeDirect {
return vc.ExtendedWriter.WriteBuffer(buffer2)
}
if buffer2.Len() > 6 && bytes.Equal(buffer2.To(3), tlsApplicationDataStart) {
command := commandPaddingEnd
if vc.enableXTLS {
command = commandPaddingDirect
vc.writeDirect = true
}
vc.writeFilterApplicationData = false
ApplyPadding(buffer2, command, vc.id)
}
err = vc.ExtendedWriter.WriteBuffer(buffer2)
}
if vc.writeDirect {
vc.ExtendedWriter = N.NewExtendedWriter(vc.Conn)
time.Sleep(5 * time.Millisecond)
}
return err
}
return vc.ExtendedWriter.WriteBuffer(buffer)
}
func (vc *Conn) sendRequest(p []byte) bool {
vc.handshakeMutex.Lock()
defer vc.handshakeMutex.Unlock()
select {
case <-vc.handshake:
// The handshake has been completed before.
// So return false to remind the caller.
return false
default:
}
defer close(vc.handshake)
var addonsBytes []byte
if vc.addons != nil {
addonsBytes, vc.err = proto.Marshal(vc.addons)
if vc.err != nil {
return true
}
}
isVision := vc.IsXTLSVisionEnabled()
var buffer *buf.Buffer
if isVision {
_buffer := buf.StackNew()
defer buf.KeepAlive(_buffer)
buffer = buf.Dup(_buffer)
defer buffer.Release()
} else {
requestLen := 1 // protocol version
requestLen += 16 // UUID
requestLen += 1 // addons length
requestLen += len(addonsBytes)
requestLen += 1 // command
if !vc.dst.Mux {
requestLen += 2 // port
requestLen += 1 // addr type
requestLen += len(vc.dst.Addr)
}
requestLen += len(p)
_buffer := buf.StackNewSize(requestLen)
defer buf.KeepAlive(_buffer)
buffer = buf.Dup(_buffer)
defer buffer.Release()
}
buf.Must(
buffer.WriteByte(Version), // protocol version
buf.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid
buffer.WriteByte(byte(len(addonsBytes))),
buf.Error(buffer.Write(addonsBytes)),
)
if vc.dst.Mux {
buf.Must(buffer.WriteByte(CommandMux))
} else {
if vc.dst.UDP {
buf.Must(buffer.WriteByte(CommandUDP))
} else {
buf.Must(buffer.WriteByte(CommandTCP))
}
binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port)
buf.Must(
buffer.WriteByte(vc.dst.AddrType),
buf.Error(buffer.Write(vc.dst.Addr)),
)
}
if isVision && !vc.dst.UDP && !vc.dst.Mux {
if len(p) == 0 {
vc.packetsToFilter = 0
vc.writeFilterApplicationData = false
WriteWithPadding(buffer, nil, commandPaddingEnd, vc.id)
} else {
vc.FilterTLS(p)
if vc.isTLS {
WriteWithPadding(buffer, p, commandPaddingContinue, vc.id)
} else {
buf.Must(buf.Error(buffer.Write(p)))
}
}
} else {
buf.Must(buf.Error(buffer.Write(p)))
}
_, vc.err = vc.ExtendedWriter.Write(buffer.Bytes())
if vc.err != nil {
return true
}
if isVision {
switch underlying := vc.tlsConn.(type) {
case *gotls.Conn:
if underlying.ConnectionState().Version != gotls.VersionTLS13 {
vc.err = ErrNotTLS13
}
case *utls.UConn:
if underlying.ConnectionState().Version != utls.VersionTLS13 {
vc.err = ErrNotTLS13
}
case *tlsC.UConn:
if underlying.ConnectionState().Version != utls.VersionTLS13 {
vc.err = ErrNotTLS13
}
default:
vc.err = fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, vc.addons.Flow)
}
vc.tlsConn = nil
}
return true
}
func (vc *Conn) recvResponse() error {
var buf [1]byte
_, vc.err = io.ReadFull(vc.ExtendedReader, buf[:])
if vc.err != nil {
return vc.err
}
if buf[0] != Version {
return errors.New("unexpected response version")
}
_, vc.err = io.ReadFull(vc.ExtendedReader, buf[:])
if vc.err != nil {
return vc.err
}
length := int64(buf[0])
if length != 0 { // addon data length > 0
io.CopyN(io.Discard, vc.ExtendedReader, length) // just discard
}
return nil
}
func (vc *Conn) FrontHeadroom() int {
if vc.IsXTLSVisionEnabled() {
return paddingHeaderLen
}
return 0
}
func (vc *Conn) Upstream() any {
return vc.Conn
}
func (vc *Conn) IsXTLSVisionEnabled() bool {
return vc.addons != nil && vc.addons.Flow == XRV
}
// newConn return a Conn instance
func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
c := &Conn{
ExtendedReader: N.NewExtendedReader(conn),
ExtendedWriter: N.NewExtendedWriter(conn),
Conn: conn,
id: client.uuid,
dst: dst,
handshake: make(chan struct{}),
}
if !dst.UDP && client.Addons != nil {
switch client.Addons.Flow {
case XRO, XRD, XRS:
if xtlsConn, ok := conn.(*xtls.Conn); ok {
xtlsConn.RPRX = true
xtlsConn.SHOW = client.XTLSShow
xtlsConn.MARK = "XTLS"
if client.Addons.Flow == XRS {
client.Addons.Flow = XRD
}
if client.Addons.Flow == XRD {
xtlsConn.DirectMode = true
}
c.addons = client.Addons
} else {
return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", client.Addons.Flow)
}
case XRV:
c.packetsToFilter = 2048
c.writeFilterApplicationData = true
var t reflect.Type
var p uintptr
switch underlying := conn.(type) {
case *gotls.Conn:
c.Conn = underlying.NetConn()
c.tlsConn = conn
t = reflect.TypeOf(underlying).Elem()
p = uintptr(unsafe.Pointer(underlying))
case *utls.UConn:
c.Conn = underlying.NetConn()
c.tlsConn = conn
t = reflect.TypeOf(underlying.Conn).Elem()
p = uintptr(unsafe.Pointer(underlying.Conn))
case *tlsC.UConn:
c.Conn = underlying.NetConn()
c.tlsConn = underlying.UConn
t = reflect.TypeOf(underlying.Conn).Elem()
p = uintptr(unsafe.Pointer(underlying.Conn))
default:
return nil, fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, client.Addons.Flow)
}
i, _ := t.FieldByName("input")
r, _ := t.FieldByName("rawInput")
c.input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset))
c.rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset))
if _, ok := c.Conn.(*net.TCPConn); !ok {
log.Debugln("XTLS underlying conn is not *net.TCPConn, got", reflect.TypeOf(conn).Name())
}
}
}
go func() {
select {
case <-c.handshake:
case <-time.After(200 * time.Millisecond):
c.sendRequest(nil)
}
}()
return c, nil
}