fix: remove cyclic dependent to make tuic's Finalizer work

This commit is contained in:
wwqgtxx
2022-12-02 16:56:17 +08:00
parent bc5ab3120f
commit 0aefa3be85
7 changed files with 209 additions and 159 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
var (
@ -26,9 +27,9 @@ var (
TooManyOpenStreams = errors.New("tuic: too many open streams")
)
type ClientOption struct {
DialFn func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error)
type DialFunc func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error)
type ClientOption struct {
TlsConfig *tls.Config
QuicConfig *quic.Config
Host string
@ -42,7 +43,7 @@ type ClientOption struct {
MaxOpenStreams int64
}
type Client struct {
type clientImpl struct {
*ClientOption
udp bool
@ -55,18 +56,17 @@ type Client struct {
udpInputMap sync.Map
// only ready for PoolClient
poolRef *PoolClient
optionRef any
lastVisited time.Time
}
func (t *Client) getQuicConn(ctx context.Context) (quic.Connection, error) {
func (t *clientImpl) getQuicConn(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (quic.Connection, error) {
t.connMutex.Lock()
defer t.connMutex.Unlock()
if t.quicConn != nil {
return t.quicConn, nil
}
pc, addr, err := t.DialFn(ctx)
pc, addr, err := dialFn(ctx, opts...)
if err != nil {
return nil, err
}
@ -97,7 +97,7 @@ func (t *Client) getQuicConn(ctx context.Context) (quic.Connection, error) {
return quicConn, nil
}
func (t *Client) sendAuthentication(quicConn quic.Connection) (err error) {
func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
defer func() {
t.deferQuicConn(quicConn, err)
}()
@ -122,7 +122,7 @@ func (t *Client) sendAuthentication(quicConn quic.Connection) (err error) {
return nil
}
func (t *Client) parseUDP(quicConn quic.Connection) (err error) {
func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) {
defer func() {
t.deferQuicConn(quicConn, err)
}()
@ -199,45 +199,50 @@ func (t *Client) parseUDP(quicConn quic.Connection) (err error) {
}
}
func (t *Client) deferQuicConn(quicConn quic.Connection, err error) {
func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {
var netError net.Error
if err != nil && errors.As(err, &netError) {
t.connMutex.Lock()
defer t.connMutex.Unlock()
if t.quicConn == quicConn {
t.forceClose(err, true)
t.forceClose(quicConn, err)
}
}
func (t *clientImpl) forceClose(quicConn quic.Connection, err error) {
t.connMutex.Lock()
defer t.connMutex.Unlock()
if quicConn == nil {
quicConn = t.quicConn
}
if quicConn != nil {
if quicConn == t.quicConn {
t.quicConn = nil
}
}
}
func (t *Client) forceClose(err error, locked bool) {
if !locked {
t.connMutex.Lock()
defer t.connMutex.Unlock()
errStr := ""
if err != nil {
errStr = err.Error()
}
quicConn := t.quicConn
if quicConn != nil {
_ = quicConn.CloseWithError(ProtocolError, err.Error())
t.udpInputMap.Range(func(key, value any) bool {
if conn, ok := value.(net.Conn); ok {
_ = conn.Close()
}
t.udpInputMap.Delete(key)
return true
})
t.quicConn = nil
_ = quicConn.CloseWithError(ProtocolError, errStr)
}
udpInputMap := &t.udpInputMap
udpInputMap.Range(func(key, value any) bool {
if conn, ok := value.(net.Conn); ok {
_ = conn.Close()
}
udpInputMap.Delete(key)
return true
})
}
func (t *Client) Close() {
func (t *clientImpl) Close() {
t.closed.Store(true)
if t.openStreams.Load() == 0 {
t.forceClose(ClientClosed, false)
t.forceClose(nil, ClientClosed)
}
}
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx)
func (t *clientImpl) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn, opts...)
if err != nil {
return nil, err
}
@ -264,12 +269,11 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata) (net.Con
Stream: quicStream,
lAddr: quicConn.LocalAddr(),
rAddr: quicConn.RemoteAddr(),
ref: t,
closeDeferFn: func() {
time.AfterFunc(C.DefaultTCPTimeout, func() {
openStreams := t.openStreams.Add(-1)
if openStreams == 0 && t.closed.Load() {
t.forceClose(ClientClosed, false)
t.forceClose(quicConn, ClientClosed)
}
})
},
@ -335,8 +339,8 @@ func (conn *earlyConn) Read(b []byte) (n int, err error) {
return conn.BufferedConn.Read(b)
}
func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx)
func (t *clientImpl) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn, opts...)
if err != nil {
return nil, err
}
@ -362,14 +366,13 @@ func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
inputConn: N.NewBufferedConn(pipe2),
udpRelayMode: t.UdpRelayMode,
maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize,
ref: t,
deferQuicConnFn: t.deferQuicConn,
closeDeferFn: func() {
t.udpInputMap.Delete(connId)
time.AfterFunc(C.DefaultUDPTimeout, func() {
openStreams := t.openStreams.Add(-1)
if openStreams == 0 && t.closed.Load() {
t.forceClose(ClientClosed, false)
t.forceClose(quicConn, ClientClosed)
}
})
},
@ -377,15 +380,42 @@ func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
return pc, nil
}
type Client struct {
*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
}
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) {
conn, err := t.clientImpl.DialContext(ctx, metadata, dialFn, opts...)
if err != nil {
return nil, err
}
return N.NewRefConn(conn, t), err
}
func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) {
pc, err := t.clientImpl.ListenPacketContext(ctx, metadata, dialFn, opts...)
if err != nil {
return nil, err
}
return N.NewRefPacketConn(pc, t), nil
}
func (t *Client) forceClose() {
t.clientImpl.forceClose(nil, ClientClosed)
}
func NewClient(clientOption *ClientOption, udp bool) *Client {
c := &Client{
ci := &clientImpl{
ClientOption: clientOption,
udp: udp,
}
c := &Client{ci}
runtime.SetFinalizer(c, closeClient)
log.Debugln("New Tuic Client at %p", c)
return c
}
func closeClient(client *Client) {
client.forceClose(ClientClosed, false)
log.Debugln("Close Tuic Client at %p", client)
client.forceClose()
}