feat: support tuicV5
This commit is contained in:
454
transport/tuic/v4/client.go
Normal file
454
transport/tuic/v4/client.go
Normal file
@ -0,0 +1,454 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
atomic2 "github.com/Dreamacro/clash/common/atomic"
|
||||
"github.com/Dreamacro/clash/common/buf"
|
||||
N "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/common/pool"
|
||||
C "github.com/Dreamacro/clash/constant"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
"github.com/Dreamacro/clash/transport/tuic/common"
|
||||
|
||||
"github.com/metacubex/quic-go"
|
||||
"github.com/zhangyunhao116/fastrand"
|
||||
)
|
||||
|
||||
type ClientOption struct {
|
||||
TlsConfig *tls.Config
|
||||
QuicConfig *quic.Config
|
||||
Token [32]byte
|
||||
UdpRelayMode string
|
||||
CongestionController string
|
||||
ReduceRtt bool
|
||||
RequestTimeout time.Duration
|
||||
MaxUdpRelayPacketSize int
|
||||
FastOpen bool
|
||||
MaxOpenStreams int64
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
*ClientOption
|
||||
udp bool
|
||||
|
||||
quicConn quic.Connection
|
||||
connMutex sync.Mutex
|
||||
|
||||
openStreams atomic.Int64
|
||||
closed atomic.Bool
|
||||
|
||||
udpInputMap sync.Map
|
||||
|
||||
// only ready for PoolClient
|
||||
dialerRef C.Dialer
|
||||
lastVisited atomic2.TypedValue[time.Time]
|
||||
}
|
||||
|
||||
func (t *clientImpl) OpenStreams() int64 {
|
||||
return t.openStreams.Load()
|
||||
}
|
||||
|
||||
func (t *clientImpl) DialerRef() C.Dialer {
|
||||
return t.dialerRef
|
||||
}
|
||||
|
||||
func (t *clientImpl) LastVisited() time.Time {
|
||||
return t.lastVisited.Load()
|
||||
}
|
||||
|
||||
func (t *clientImpl) SetLastVisited(last time.Time) {
|
||||
t.lastVisited.Store(last)
|
||||
}
|
||||
|
||||
func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) {
|
||||
t.connMutex.Lock()
|
||||
defer t.connMutex.Unlock()
|
||||
if t.quicConn != nil {
|
||||
return t.quicConn, nil
|
||||
}
|
||||
transport, addr, err := dialFn(ctx, dialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var quicConn quic.Connection
|
||||
if t.ReduceRtt {
|
||||
quicConn, err = transport.DialEarly(ctx, addr, t.TlsConfig, t.QuicConfig)
|
||||
} else {
|
||||
quicConn, err = transport.Dial(ctx, addr, t.TlsConfig, t.QuicConfig)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
common.SetCongestionController(quicConn, t.CongestionController)
|
||||
|
||||
go func() {
|
||||
_ = t.sendAuthentication(quicConn)
|
||||
}()
|
||||
|
||||
if t.udp {
|
||||
go func() {
|
||||
_ = t.parseUDP(quicConn)
|
||||
}()
|
||||
}
|
||||
|
||||
t.quicConn = quicConn
|
||||
t.openStreams.Store(0)
|
||||
return quicConn, nil
|
||||
}
|
||||
|
||||
func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
|
||||
defer func() {
|
||||
t.deferQuicConn(quicConn, err)
|
||||
}()
|
||||
stream, err := quicConn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf := pool.GetBuffer()
|
||||
defer pool.PutBuffer(buf)
|
||||
err = NewAuthenticate(t.Token).WriteTo(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = buf.WriteTo(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = stream.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
|
||||
defer func() {
|
||||
t.deferQuicConn(quicConn, err)
|
||||
}()
|
||||
switch t.UdpRelayMode {
|
||||
case "quic":
|
||||
for {
|
||||
var stream quic.ReceiveStream
|
||||
stream, err = quicConn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() (err error) {
|
||||
var assocId uint32
|
||||
defer func() {
|
||||
t.deferQuicConn(quicConn, err)
|
||||
if err != nil && assocId != 0 {
|
||||
if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok {
|
||||
if conn, ok := val.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
stream.CancelRead(0)
|
||||
}()
|
||||
reader := bufio.NewReader(stream)
|
||||
packet, err := ReadPacket(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assocId = packet.ASSOC_ID
|
||||
if val, ok := t.udpInputMap.Load(assocId); ok {
|
||||
if conn, ok := val.(net.Conn); ok {
|
||||
writer := bufio.NewWriterSize(conn, packet.BytesLen())
|
||||
_ = packet.WriteTo(writer)
|
||||
_ = writer.Flush()
|
||||
}
|
||||
}
|
||||
return
|
||||
}()
|
||||
}
|
||||
default: // native
|
||||
for {
|
||||
var message []byte
|
||||
message, err = quicConn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() (err error) {
|
||||
var assocId uint32
|
||||
defer func() {
|
||||
t.deferQuicConn(quicConn, err)
|
||||
if err != nil && assocId != 0 {
|
||||
if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok {
|
||||
if conn, ok := val.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
buffer := bytes.NewBuffer(message)
|
||||
packet, err := ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assocId = packet.ASSOC_ID
|
||||
if val, ok := t.udpInputMap.Load(assocId); ok {
|
||||
if conn, ok := val.(net.Conn); ok {
|
||||
_, _ = conn.Write(message)
|
||||
}
|
||||
}
|
||||
return
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {
|
||||
var netError net.Error
|
||||
if err != nil && errors.As(err, &netError) {
|
||||
t.forceClose(quicConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *clientImpl) forceClose(quicConn quic.Connection, err error) {
|
||||
t.connMutex.Lock()
|
||||
defer t.connMutex.Unlock()
|
||||
if quicConn == nil {
|
||||
quicConn = t.quicConn
|
||||
}
|
||||
if quicConn != nil {
|
||||
if quicConn == t.quicConn {
|
||||
t.quicConn = nil
|
||||
}
|
||||
}
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = err.Error()
|
||||
}
|
||||
if quicConn != nil {
|
||||
_ = quicConn.CloseWithError(ProtocolError, errStr)
|
||||
}
|
||||
udpInputMap := &t.udpInputMap
|
||||
udpInputMap.Range(func(key, value any) bool {
|
||||
if conn, ok := value.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
}
|
||||
udpInputMap.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (t *clientImpl) Close() {
|
||||
t.closed.Store(true)
|
||||
if t.openStreams.Load() == 0 {
|
||||
t.forceClose(nil, common.ClientClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
|
||||
quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openStreams := t.openStreams.Add(1)
|
||||
if openStreams >= t.MaxOpenStreams {
|
||||
t.openStreams.Add(-1)
|
||||
return nil, common.TooManyOpenStreams
|
||||
}
|
||||
stream, err := func() (stream net.Conn, err error) {
|
||||
defer func() {
|
||||
t.deferQuicConn(quicConn, err)
|
||||
}()
|
||||
buf := pool.GetBuffer()
|
||||
defer pool.PutBuffer(buf)
|
||||
err = NewConnect(NewAddress(metadata)).WriteTo(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quicStream, err := quicConn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream = common.NewQuicStreamConn(
|
||||
quicStream,
|
||||
quicConn.LocalAddr(),
|
||||
quicConn.RemoteAddr(),
|
||||
func() {
|
||||
time.AfterFunc(C.DefaultTCPTimeout, func() {
|
||||
openStreams := t.openStreams.Add(-1)
|
||||
if openStreams == 0 && t.closed.Load() {
|
||||
t.forceClose(quicConn, common.ClientClosed)
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
_, err = buf.WriteTo(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
return stream, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bufConn := N.NewBufferedConn(stream)
|
||||
conn := &earlyConn{ExtendedConn: bufConn, bufConn: bufConn, RequestTimeout: t.RequestTimeout}
|
||||
if !t.FastOpen {
|
||||
err = conn.Response()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type earlyConn struct {
|
||||
N.ExtendedConn // only expose standard N.ExtendedConn function to outside
|
||||
bufConn *N.BufferedConn
|
||||
resOnce sync.Once
|
||||
resErr error
|
||||
|
||||
RequestTimeout time.Duration
|
||||
}
|
||||
|
||||
func (conn *earlyConn) response() error {
|
||||
if conn.RequestTimeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout))
|
||||
}
|
||||
response, err := ReadResponse(conn.bufConn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
if response.IsFailed() {
|
||||
_ = conn.Close()
|
||||
return errors.New("connect failed")
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *earlyConn) Response() error {
|
||||
conn.resOnce.Do(func() {
|
||||
conn.resErr = conn.response()
|
||||
})
|
||||
return conn.resErr
|
||||
}
|
||||
|
||||
func (conn *earlyConn) Read(b []byte) (n int, err error) {
|
||||
err = conn.Response()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return conn.bufConn.Read(b)
|
||||
}
|
||||
|
||||
func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) {
|
||||
err = conn.Response()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.bufConn.ReadBuffer(buffer)
|
||||
}
|
||||
|
||||
func (conn *earlyConn) Upstream() any {
|
||||
return conn.bufConn
|
||||
}
|
||||
|
||||
func (conn *earlyConn) ReaderReplaceable() bool {
|
||||
return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil
|
||||
}
|
||||
|
||||
func (conn *earlyConn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
|
||||
quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openStreams := t.openStreams.Add(1)
|
||||
if openStreams >= t.MaxOpenStreams {
|
||||
t.openStreams.Add(-1)
|
||||
return nil, common.TooManyOpenStreams
|
||||
}
|
||||
|
||||
pipe1, pipe2 := net.Pipe()
|
||||
var connId uint32
|
||||
for {
|
||||
connId = fastrand.Uint32()
|
||||
_, loaded := t.udpInputMap.LoadOrStore(connId, pipe1)
|
||||
if !loaded {
|
||||
break
|
||||
}
|
||||
}
|
||||
pc := &quicStreamPacketConn{
|
||||
connId: connId,
|
||||
quicConn: quicConn,
|
||||
inputConn: N.NewBufferedConn(pipe2),
|
||||
udpRelayMode: t.UdpRelayMode,
|
||||
maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize,
|
||||
deferQuicConnFn: t.deferQuicConn,
|
||||
closeDeferFn: func() {
|
||||
t.udpInputMap.Delete(connId)
|
||||
time.AfterFunc(C.DefaultUDPTimeout, func() {
|
||||
openStreams := t.openStreams.Add(-1)
|
||||
if openStreams == 0 && t.closed.Load() {
|
||||
t.forceClose(quicConn, common.ClientClosed)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
|
||||
}
|
||||
|
||||
func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
|
||||
conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return N.NewRefConn(conn, t), err
|
||||
}
|
||||
|
||||
func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
|
||||
pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return N.NewRefPacketConn(pc, t), nil
|
||||
}
|
||||
|
||||
func (t *Client) forceClose() {
|
||||
t.clientImpl.forceClose(nil, common.ClientClosed)
|
||||
}
|
||||
|
||||
func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client {
|
||||
ci := &clientImpl{
|
||||
ClientOption: clientOption,
|
||||
udp: udp,
|
||||
dialerRef: dialerRef,
|
||||
}
|
||||
c := &Client{ci}
|
||||
runtime.SetFinalizer(c, closeClient)
|
||||
log.Debugln("New TuicV4 Client at %p", c)
|
||||
return c
|
||||
}
|
||||
|
||||
func closeClient(client *Client) {
|
||||
log.Debugln("Close TuicV4 Client at %p", client)
|
||||
client.forceClose()
|
||||
}
|
Reference in New Issue
Block a user