chore: cleanup code

This commit is contained in:
wwqgtxx
2022-12-22 09:53:11 +08:00
parent 63922f86a2
commit 980454beb2
5 changed files with 45 additions and 98 deletions

View File

@ -17,7 +17,6 @@ import (
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
@ -27,8 +26,7 @@ var (
TooManyOpenStreams = errors.New("tuic: too many open streams")
)
type DialFunc func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error)
type DialWithDialerFunc func(ctx context.Context, dialer C.Dialer) (pc net.PacketConn, addr net.Addr, err error)
type DialFunc func(ctx context.Context, dialer C.Dialer) (pc net.PacketConn, addr net.Addr, err error)
type ClientOption struct {
TlsConfig *tls.Config
@ -57,17 +55,17 @@ type clientImpl struct {
udpInputMap sync.Map
// only ready for PoolClient
optionRef any
dialerRef C.Dialer
lastVisited time.Time
}
func (t *clientImpl) getQuicConn(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (quic.Connection, error) {
func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn DialFunc) (quic.Connection, error) {
t.connMutex.Lock()
defer t.connMutex.Unlock()
if t.quicConn != nil {
return t.quicConn, nil
}
pc, addr, err := dialFn(ctx, opts...)
pc, addr, err := dialFn(ctx, dialer)
if err != nil {
return nil, err
}
@ -242,8 +240,8 @@ func (t *clientImpl) Close() {
}
}
func (t *clientImpl) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn, opts...)
func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
if err != nil {
return nil, err
}
@ -340,8 +338,8 @@ func (conn *earlyConn) Read(b []byte) (n int, err error) {
return conn.BufferedConn.Read(b)
}
func (t *clientImpl) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn, opts...)
func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) {
quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
if err != nil {
return nil, err
}
@ -385,16 +383,16 @@ type Client struct {
*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
}
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) {
conn, err := t.clientImpl.DialContext(ctx, metadata, dialFn, opts...)
func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) {
conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn)
if err != nil {
return nil, err
}
return N.NewRefConn(conn, t), err
}
func (t *Client) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) {
pc, err := t.clientImpl.ListenPacketContext(ctx, metadata, dialFn, opts...)
func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) {
pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
if err != nil {
return nil, err
}

View File

@ -10,7 +10,6 @@ import (
"github.com/Dreamacro/clash/common/generics/list"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
@ -25,7 +24,7 @@ type PoolClient struct {
*ClientOption
newClientOption *ClientOption
dialResultMap map[any]dialResult
dialResultMap map[C.Dialer]dialResult
dialResultMutex *sync.Mutex
tcpClients *list.List[*Client]
tcpClientsMutex *sync.Mutex
@ -33,14 +32,10 @@ type PoolClient struct {
udpClientsMutex *sync.Mutex
}
func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.Conn, error) {
newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
return t.dial(ctx, dialFn, opts...)
}
var o any = *dialer.ApplyOptions(opts...)
conn, err := t.getClient(false, o).DialContext(ctx, metadata, newDialFn, opts...)
func (t *PoolClient) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) {
conn, err := t.getClient(false, dialer).DialContextWithDialer(ctx, metadata, dialer, dialFn)
if errors.Is(err, TooManyOpenStreams) {
conn, err = t.newClient(false, o).DialContext(ctx, metadata, newDialFn, opts...)
conn, err = t.newClient(false, dialer).DialContextWithDialer(ctx, metadata, dialer, dialFn)
}
if err != nil {
return nil, err
@ -48,29 +43,10 @@ func (t *PoolClient) DialContext(ctx context.Context, metadata *C.Metadata, dial
return N.NewRefConn(conn, t), err
}
func (t *PoolClient) DialContextWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata, dialFn DialWithDialerFunc) (net.Conn, error) {
newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
return dialFn(ctx, d)
}
var o any = d
conn, err := t.getClient(false, o).DialContext(ctx, metadata, newDialFn)
func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) {
pc, err := t.getClient(true, dialer).ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
if errors.Is(err, TooManyOpenStreams) {
conn, err = t.newClient(false, o).DialContext(ctx, metadata, newDialFn)
}
if err != nil {
return nil, err
}
return N.NewRefConn(conn, t), err
}
func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metadata, dialFn DialFunc, opts ...dialer.Option) (net.PacketConn, error) {
newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
return t.dial(ctx, dialFn, opts...)
}
var o any = *dialer.ApplyOptions(opts...)
pc, err := t.getClient(true, o).ListenPacketContext(ctx, metadata, newDialFn, opts...)
if errors.Is(err, TooManyOpenStreams) {
pc, err = t.newClient(true, o).ListenPacketContext(ctx, metadata, newDialFn, opts...)
pc, err = t.newClient(true, dialer).ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
}
if err != nil {
return nil, err
@ -78,32 +54,15 @@ func (t *PoolClient) ListenPacketContext(ctx context.Context, metadata *C.Metada
return N.NewRefPacketConn(pc, t), nil
}
func (t *PoolClient) ListenPacketWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata, dialFn DialWithDialerFunc) (net.PacketConn, error) {
newDialFn := func(ctx context.Context, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
return dialFn(ctx, d)
}
var o any = d
pc, err := t.getClient(true, o).ListenPacketContext(ctx, metadata, newDialFn)
if errors.Is(err, TooManyOpenStreams) {
pc, err = t.newClient(true, o).ListenPacketContext(ctx, metadata, newDialFn)
}
if err != nil {
return nil, err
}
return N.NewRefPacketConn(pc, t), nil
}
func (t *PoolClient) dial(ctx context.Context, dialFn DialFunc, opts ...dialer.Option) (pc net.PacketConn, addr net.Addr, err error) {
var o any = *dialer.ApplyOptions(opts...)
func (t *PoolClient) dial(ctx context.Context, dialer C.Dialer, dialFn DialFunc) (pc net.PacketConn, addr net.Addr, err error) {
t.dialResultMutex.Lock()
dr, ok := t.dialResultMap[o]
dr, ok := t.dialResultMap[dialer]
t.dialResultMutex.Unlock()
if ok {
return dr.pc, dr.addr, dr.err
}
pc, addr, err = dialFn(ctx, opts...)
pc, addr, err = dialFn(ctx, dialer)
if err != nil {
return nil, nil, err
}
@ -111,7 +70,7 @@ func (t *PoolClient) dial(ctx context.Context, dialFn DialFunc, opts ...dialer.O
dr.pc, dr.addr, dr.err = pc, addr, err
t.dialResultMutex.Lock()
t.dialResultMap[o] = dr
t.dialResultMap[dialer] = dr
t.dialResultMutex.Unlock()
return pc, addr, err
}
@ -128,7 +87,7 @@ func (t *PoolClient) forceClose() {
}
}
func (t *PoolClient) newClient(udp bool, o any) *Client {
func (t *PoolClient) newClient(udp bool, dialer C.Dialer) *Client {
clients := t.tcpClients
clientsMutex := t.tcpClientsMutex
if udp {
@ -140,14 +99,14 @@ func (t *PoolClient) newClient(udp bool, o any) *Client {
defer clientsMutex.Unlock()
client := NewClient(t.newClientOption, udp)
client.optionRef = o
client.dialerRef = dialer
client.lastVisited = time.Now()
clients.PushFront(client)
return client
}
func (t *PoolClient) getClient(udp bool, o any) *Client {
func (t *PoolClient) getClient(udp bool, dialer C.Dialer) *Client {
clients := t.tcpClients
clientsMutex := t.tcpClientsMutex
if udp {
@ -167,7 +126,7 @@ func (t *PoolClient) getClient(udp bool, o any) *Client {
it = next
continue
}
if client.optionRef == o {
if client.dialerRef == dialer {
if bestClient == nil {
bestClient = client
} else {
@ -192,7 +151,7 @@ func (t *PoolClient) getClient(udp bool, o any) *Client {
}
if bestClient == nil {
return t.newClient(udp, o)
return t.newClient(udp, dialer)
} else {
bestClient.lastVisited = time.Now()
return bestClient
@ -202,7 +161,7 @@ func (t *PoolClient) getClient(udp bool, o any) *Client {
func NewPoolClient(clientOption *ClientOption) *PoolClient {
p := &PoolClient{
ClientOption: clientOption,
dialResultMap: make(map[any]dialResult),
dialResultMap: make(map[C.Dialer]dialResult),
dialResultMutex: &sync.Mutex{},
tcpClients: list.New[*Client](),
tcpClientsMutex: &sync.Mutex{},