diff --git a/adapter/outboundgroup/fallback.go b/adapter/outboundgroup/fallback.go index 9af49382..cdd2c79a 100644 --- a/adapter/outboundgroup/fallback.go +++ b/adapter/outboundgroup/fallback.go @@ -14,7 +14,7 @@ import ( type Fallback struct { *outbound.Base disableUDP bool - single *singledo.Single + single *singledo.Single[[]C.Proxy] providers []provider.ProxyProvider } @@ -73,11 +73,11 @@ func (f *Fallback) Unwrap(metadata *C.Metadata) C.Proxy { } func (f *Fallback) proxies(touch bool) []C.Proxy { - elm, _, _ := f.single.Do(func() (any, error) { + elm, _, _ := f.single.Do(func() ([]C.Proxy, error) { return getProvidersProxies(f.providers, touch), nil }) - return elm.([]C.Proxy) + return elm } func (f *Fallback) findAliveProxy(touch bool) C.Proxy { @@ -99,7 +99,7 @@ func NewFallback(option *GroupCommonOption, providers []provider.ProxyProvider) Interface: option.Interface, RoutingMark: option.RoutingMark, }), - single: singledo.NewSingle(defaultGetProxiesDuration), + single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration), providers: providers, disableUDP: option.DisableUDP, } diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go index d0ae4fe3..7c9f20d7 100644 --- a/adapter/outboundgroup/loadbalance.go +++ b/adapter/outboundgroup/loadbalance.go @@ -22,7 +22,7 @@ type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy type LoadBalance struct { *outbound.Base disableUDP bool - single *singledo.Single + single *singledo.Single[[]C.Proxy] providers []provider.ProxyProvider strategyFn strategyFn } @@ -140,11 +140,11 @@ func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy { } func (lb *LoadBalance) proxies(touch bool) []C.Proxy { - elm, _, _ := lb.single.Do(func() (any, error) { + elm, _, _ := lb.single.Do(func() ([]C.Proxy, error) { return getProvidersProxies(lb.providers, touch), nil }) - return elm.([]C.Proxy) + return elm } // MarshalJSON implements C.ProxyAdapter @@ -176,7 +176,7 @@ func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvide Interface: option.Interface, RoutingMark: option.RoutingMark, }), - single: singledo.NewSingle(defaultGetProxiesDuration), + single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration), providers: providers, strategyFn: strategyFn, disableUDP: option.DisableUDP, diff --git a/adapter/outboundgroup/relay.go b/adapter/outboundgroup/relay.go index 03e6982a..7c0ca181 100644 --- a/adapter/outboundgroup/relay.go +++ b/adapter/outboundgroup/relay.go @@ -14,7 +14,7 @@ import ( type Relay struct { *outbound.Base - single *singledo.Single + single *singledo.Single[[]C.Proxy] providers []provider.ProxyProvider } @@ -79,11 +79,11 @@ func (r *Relay) MarshalJSON() ([]byte, error) { } func (r *Relay) rawProxies(touch bool) []C.Proxy { - elm, _, _ := r.single.Do(func() (any, error) { + elm, _, _ := r.single.Do(func() ([]C.Proxy, error) { return getProvidersProxies(r.providers, touch), nil }) - return elm.([]C.Proxy) + return elm } func (r *Relay) proxies(metadata *C.Metadata, touch bool) []C.Proxy { @@ -108,7 +108,7 @@ func NewRelay(option *GroupCommonOption, providers []provider.ProxyProvider) *Re Interface: option.Interface, RoutingMark: option.RoutingMark, }), - single: singledo.NewSingle(defaultGetProxiesDuration), + single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration), providers: providers, } } diff --git a/adapter/outboundgroup/selector.go b/adapter/outboundgroup/selector.go index 3975df77..da71022f 100644 --- a/adapter/outboundgroup/selector.go +++ b/adapter/outboundgroup/selector.go @@ -15,7 +15,7 @@ import ( type Selector struct { *outbound.Base disableUDP bool - single *singledo.Single + single *singledo.Single[C.Proxy] selected string providers []provider.ProxyProvider } @@ -83,7 +83,7 @@ func (s *Selector) Unwrap(metadata *C.Metadata) C.Proxy { } func (s *Selector) selectedProxy(touch bool) C.Proxy { - elm, _, _ := s.single.Do(func() (any, error) { + elm, _, _ := s.single.Do(func() (C.Proxy, error) { proxies := getProvidersProxies(s.providers, touch) for _, proxy := range proxies { if proxy.Name() == s.selected { @@ -94,7 +94,7 @@ func (s *Selector) selectedProxy(touch bool) C.Proxy { return proxies[0], nil }) - return elm.(C.Proxy) + return elm } func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider) *Selector { @@ -106,7 +106,7 @@ func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider) Interface: option.Interface, RoutingMark: option.RoutingMark, }), - single: singledo.NewSingle(defaultGetProxiesDuration), + single: singledo.NewSingle[C.Proxy](defaultGetProxiesDuration), providers: providers, selected: selected, disableUDP: option.DisableUDP, diff --git a/adapter/outboundgroup/urltest.go b/adapter/outboundgroup/urltest.go index 61597498..bd507e51 100644 --- a/adapter/outboundgroup/urltest.go +++ b/adapter/outboundgroup/urltest.go @@ -25,8 +25,8 @@ type URLTest struct { tolerance uint16 disableUDP bool fastNode C.Proxy - single *singledo.Single - fastSingle *singledo.Single + single *singledo.Single[[]C.Proxy] + fastSingle *singledo.Single[C.Proxy] providers []provider.ProxyProvider } @@ -58,15 +58,15 @@ func (u *URLTest) Unwrap(metadata *C.Metadata) C.Proxy { } func (u *URLTest) proxies(touch bool) []C.Proxy { - elm, _, _ := u.single.Do(func() (any, error) { + elm, _, _ := u.single.Do(func() ([]C.Proxy, error) { return getProvidersProxies(u.providers, touch), nil }) - return elm.([]C.Proxy) + return elm } func (u *URLTest) fast(touch bool) C.Proxy { - elm, _, _ := u.fastSingle.Do(func() (any, error) { + elm, _, _ := u.fastSingle.Do(func() (C.Proxy, error) { proxies := u.proxies(touch) fast := proxies[0] min := fast.LastDelay() @@ -96,7 +96,7 @@ func (u *URLTest) fast(touch bool) C.Proxy { return u.fastNode, nil }) - return elm.(C.Proxy) + return elm } // SupportUDP implements C.ProxyAdapter @@ -142,8 +142,8 @@ func NewURLTest(option *GroupCommonOption, providers []provider.ProxyProvider, o Interface: option.Interface, RoutingMark: option.RoutingMark, }), - single: singledo.NewSingle(defaultGetProxiesDuration), - fastSingle: singledo.NewSingle(time.Second * 10), + single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration), + fastSingle: singledo.NewSingle[C.Proxy](time.Second * 10), providers: providers, disableUDP: option.DisableUDP, } diff --git a/adapter/provider/healthcheck.go b/adapter/provider/healthcheck.go index bfbaf6b0..430225c4 100644 --- a/adapter/provider/healthcheck.go +++ b/adapter/provider/healthcheck.go @@ -65,14 +65,14 @@ func (hc *HealthCheck) touch() { } func (hc *HealthCheck) check() { - b, _ := batch.New(context.Background(), batch.WithConcurrencyNum(10)) + b, _ := batch.New[bool](context.Background(), batch.WithConcurrencyNum[bool](10)) for _, proxy := range hc.proxies { p := proxy - b.Go(p.Name(), func() (any, error) { + b.Go(p.Name(), func() (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout) defer cancel() - p.URLTest(ctx, hc.url) - return nil, nil + _, _ = p.URLTest(ctx, hc.url) + return false, nil }) } b.Wait() diff --git a/common/batch/batch.go b/common/batch/batch.go index f0085215..5a6360e5 100644 --- a/common/batch/batch.go +++ b/common/batch/batch.go @@ -5,10 +5,10 @@ import ( "sync" ) -type Option = func(b *Batch) +type Option[T any] func(b *Batch[T]) -type Result struct { - Value any +type Result[T any] struct { + Value T Err error } @@ -17,8 +17,8 @@ type Error struct { Err error } -func WithConcurrencyNum(n int) Option { - return func(b *Batch) { +func WithConcurrencyNum[T any](n int) Option[T] { + return func(b *Batch[T]) { q := make(chan struct{}, n) for i := 0; i < n; i++ { q <- struct{}{} @@ -28,8 +28,8 @@ func WithConcurrencyNum(n int) Option { } // Batch similar to errgroup, but can control the maximum number of concurrent -type Batch struct { - result map[string]Result +type Batch[T any] struct { + result map[string]Result[T] queue chan struct{} wg sync.WaitGroup mux sync.Mutex @@ -38,7 +38,7 @@ type Batch struct { cancel func() } -func (b *Batch) Go(key string, fn func() (any, error)) { +func (b *Batch[T]) Go(key string, fn func() (T, error)) { b.wg.Add(1) go func() { defer b.wg.Done() @@ -59,14 +59,14 @@ func (b *Batch) Go(key string, fn func() (any, error)) { }) } - ret := Result{value, err} + ret := Result[T]{value, err} b.mux.Lock() defer b.mux.Unlock() b.result[key] = ret }() } -func (b *Batch) Wait() *Error { +func (b *Batch[T]) Wait() *Error { b.wg.Wait() if b.cancel != nil { b.cancel() @@ -74,26 +74,26 @@ func (b *Batch) Wait() *Error { return b.err } -func (b *Batch) WaitAndGetResult() (map[string]Result, *Error) { +func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) { err := b.Wait() return b.Result(), err } -func (b *Batch) Result() map[string]Result { +func (b *Batch[T]) Result() map[string]Result[T] { b.mux.Lock() defer b.mux.Unlock() - copy := map[string]Result{} + copyM := map[string]Result[T]{} for k, v := range b.result { - copy[k] = v + copyM[k] = v } - return copy + return copyM } -func New(ctx context.Context, opts ...Option) (*Batch, context.Context) { +func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) { ctx, cancel := context.WithCancel(ctx) - b := &Batch{ - result: map[string]Result{}, + b := &Batch[T]{ + result: map[string]Result[T]{}, } for _, o := range opts { diff --git a/common/batch/batch_test.go b/common/batch/batch_test.go index 4e44158c..73350fd3 100644 --- a/common/batch/batch_test.go +++ b/common/batch/batch_test.go @@ -11,14 +11,14 @@ import ( ) func TestBatch(t *testing.T) { - b, _ := New(context.Background()) + b, _ := New[string](context.Background()) now := time.Now() - b.Go("foo", func() (any, error) { + b.Go("foo", func() (string, error) { time.Sleep(time.Millisecond * 100) return "foo", nil }) - b.Go("bar", func() (any, error) { + b.Go("bar", func() (string, error) { time.Sleep(time.Millisecond * 150) return "bar", nil }) @@ -32,20 +32,20 @@ func TestBatch(t *testing.T) { for k, v := range result { assert.NoError(t, v.Err) - assert.Equal(t, k, v.Value.(string)) + assert.Equal(t, k, v.Value) } } func TestBatchWithConcurrencyNum(t *testing.T) { - b, _ := New( + b, _ := New[string]( context.Background(), - WithConcurrencyNum(3), + WithConcurrencyNum[string](3), ) now := time.Now() for i := 0; i < 7; i++ { idx := i - b.Go(strconv.Itoa(idx), func() (any, error) { + b.Go(strconv.Itoa(idx), func() (string, error) { time.Sleep(time.Millisecond * 100) return strconv.Itoa(idx), nil }) @@ -57,21 +57,21 @@ func TestBatchWithConcurrencyNum(t *testing.T) { for k, v := range result { assert.NoError(t, v.Err) - assert.Equal(t, k, v.Value.(string)) + assert.Equal(t, k, v.Value) } } func TestBatchContext(t *testing.T) { - b, ctx := New(context.Background()) + b, ctx := New[string](context.Background()) - b.Go("error", func() (any, error) { + b.Go("error", func() (string, error) { time.Sleep(time.Millisecond * 100) - return nil, errors.New("test error") + return "", errors.New("test error") }) - b.Go("ctx", func() (any, error) { + b.Go("ctx", func() (string, error) { <-ctx.Done() - return nil, ctx.Err() + return "", ctx.Err() }) result, err := b.WaitAndGetResult() diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 82eca7f4..5fef9445 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -3,19 +3,20 @@ package cache // Modified by https://github.com/die-net/lrucache import ( - "container/list" "sync" "time" + + "github.com/Dreamacro/clash/common/generics/list" ) // Option is part of Functional Options Pattern type Option[K comparable, V any] func(*LruCache[K, V]) // EvictCallback is used to get a callback when a cache entry is evicted -type EvictCallback = func(key any, value any) +type EvictCallback[K comparable, V any] func(key K, value V) // WithEvict set the evict callback -func WithEvict[K comparable, V any](cb EvictCallback) Option[K, V] { +func WithEvict[K comparable, V any](cb EvictCallback[K, V]) Option[K, V] { return func(l *LruCache[K, V]) { l.onEvict = cb } @@ -57,18 +58,18 @@ type LruCache[K comparable, V any] struct { maxAge int64 maxSize int mu sync.Mutex - cache map[any]*list.Element - lru *list.List // Front is least-recent + cache map[K]*list.Element[*entry[K, V]] + lru *list.List[*entry[K, V]] // Front is least-recent updateAgeOnGet bool staleReturn bool - onEvict EvictCallback + onEvict EvictCallback[K, V] } // NewLRUCache creates an LruCache func NewLRUCache[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { lc := &LruCache[K, V]{ - lru: list.New(), - cache: make(map[any]*list.Element), + lru: list.New[*entry[K, V]](), + cache: make(map[K]*list.Element[*entry[K, V]]), } for _, option := range options { @@ -129,7 +130,7 @@ func (c *LruCache[K, V]) SetWithExpire(key K, value V, expires time.Time) { if le, ok := c.cache[key]; ok { c.lru.MoveToBack(le) - e := le.Value.(*entry[K, V]) + e := le.Value e.value = value e.expires = expires.Unix() } else { @@ -154,11 +155,11 @@ func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) { n.mu.Lock() defer n.mu.Unlock() - n.lru = list.New() - n.cache = make(map[any]*list.Element) + n.lru = list.New[*entry[K, V]]() + n.cache = make(map[K]*list.Element[*entry[K, V]]) for e := c.lru.Front(); e != nil; e = e.Next() { - elm := e.Value.(*entry[K, V]) + elm := e.Value n.cache[elm.key] = n.lru.PushBack(elm) } } @@ -172,7 +173,7 @@ func (c *LruCache[K, V]) get(key K) *entry[K, V] { return nil } - if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry[K, V]).expires <= time.Now().Unix() { + if !c.staleReturn && c.maxAge > 0 && le.Value.expires <= time.Now().Unix() { c.deleteElement(le) c.maybeDeleteOldest() @@ -180,7 +181,7 @@ func (c *LruCache[K, V]) get(key K) *entry[K, V] { } c.lru.MoveToBack(le) - el := le.Value.(*entry[K, V]) + el := le.Value if c.maxAge > 0 && c.updateAgeOnGet { el.expires = time.Now().Unix() + c.maxAge } @@ -201,15 +202,15 @@ func (c *LruCache[K, V]) Delete(key K) { func (c *LruCache[K, V]) maybeDeleteOldest() { if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix() - for le := c.lru.Front(); le != nil && le.Value.(*entry[K, V]).expires <= now; le = c.lru.Front() { + for le := c.lru.Front(); le != nil && le.Value.expires <= now; le = c.lru.Front() { c.deleteElement(le) } } } -func (c *LruCache[K, V]) deleteElement(le *list.Element) { +func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) { c.lru.Remove(le) - e := le.Value.(*entry[K, V]) + e := le.Value delete(c.cache, e.key) if c.onEvict != nil { c.onEvict(e.key, e.value) @@ -219,7 +220,7 @@ func (c *LruCache[K, V]) deleteElement(le *list.Element) { func (c *LruCache[K, V]) Clear() error { c.mu.Lock() - c.cache = make(map[any]*list.Element) + c.cache = make(map[K]*list.Element[*entry[K, V]]) c.mu.Unlock() return nil diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go index 487c184e..1a4c68ae 100644 --- a/common/cache/lrucache_test.go +++ b/common/cache/lrucache_test.go @@ -52,18 +52,18 @@ func TestLRUMaxAge(t *testing.T) { // Add one expired entry c.Set("foo", "bar") - c.lru.Back().Value.(*entry[string, string]).expires = now + c.lru.Back().Value.expires = now // Reset c.Set("foo", "bar") - e := c.lru.Back().Value.(*entry[string, string]) + e := c.lru.Back().Value assert.True(t, e.expires >= now) - c.lru.Back().Value.(*entry[string, string]).expires = now + c.lru.Back().Value.expires = now // Set a few and verify expiration times for _, s := range entries { c.Set(s.key, s.value) - e := c.lru.Back().Value.(*entry[string, string]) + e := c.lru.Back().Value assert.True(t, e.expires >= expected && e.expires <= expected+10) } @@ -77,7 +77,7 @@ func TestLRUMaxAge(t *testing.T) { for _, s := range entries { le, ok := c.cache[s.key] if assert.True(t, ok) { - le.Value.(*entry[string, string]).expires = now + le.Value.expires = now } } @@ -95,11 +95,11 @@ func TestLRUpdateOnGet(t *testing.T) { // Add one expired entry c.Set("foo", "bar") - c.lru.Back().Value.(*entry[string, string]).expires = expires + c.lru.Back().Value.expires = expires _, ok := c.Get("foo") assert.True(t, ok) - assert.True(t, c.lru.Back().Value.(*entry[string, string]).expires > expires) + assert.True(t, c.lru.Back().Value.expires > expires) } func TestMaxSize(t *testing.T) { @@ -126,8 +126,8 @@ func TestExist(t *testing.T) { func TestEvict(t *testing.T) { temp := 0 - evict := func(key any, value any) { - temp = key.(int) + value.(int) + evict := func(key int, value int) { + temp = key + value } c := NewLRUCache[int, int](WithEvict[int, int](evict), WithSize[int, int](1)) diff --git a/common/cert/cert.go b/common/cert/cert.go index 3c931665..29bec9de 100644 --- a/common/cert/cert.go +++ b/common/cert/cert.go @@ -11,6 +11,7 @@ import ( "math/big" "net" "os" + "strings" "sync/atomic" "time" ) @@ -38,19 +39,6 @@ type CertsStorage interface { Set(key string, cert *tls.Certificate) } -type CertsCache struct { - certsCache map[string]*tls.Certificate -} - -func (c *CertsCache) Get(key string) (*tls.Certificate, bool) { - v, ok := c.certsCache[key] - return v, ok -} - -func (c *CertsCache) Set(key string, cert *tls.Certificate) { - c.certsCache[key] = cert -} - func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -100,7 +88,7 @@ func NewAuthority(name, organization string, validity time.Duration) (*x509.Cert return x509c, privateKey, nil } -func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage CertsStorage) (*Config, error) { +func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*Config, error) { roots := x509.NewCertPool() roots.AddCert(ca) @@ -121,10 +109,6 @@ func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage Certs } keyID := h.Sum(nil) - if storage == nil { - storage = &CertsCache{certsCache: make(map[string]*tls.Certificate)} - } - return &Config{ ca: ca, caPrivateKey: caPrivateKey, @@ -132,7 +116,7 @@ func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage Certs keyID: keyID, validity: time.Hour, organization: "Clash", - certsStorage: storage, + certsStorage: NewDomainTrieCertsStorage(), roots: roots, }, nil } @@ -168,14 +152,11 @@ func (c *Config) NewTLSConfigForHost(hostname string) *tls.Config { } func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) { - host, _, err := net.SplitHostPort(hostname) - if err == nil { - hostname = host - } - + var leaf *x509.Certificate tlsCertificate, ok := c.certsStorage.Get(hostname) if ok { - if _, err = tlsCertificate.Leaf.Verify(x509.VerifyOptions{ + leaf = tlsCertificate.Leaf + if _, err := leaf.Verify(x509.VerifyOptions{ DNSName: hostname, Roots: c.roots, }); err == nil { @@ -183,12 +164,49 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica } } + var ( + key = hostname + topHost = hostname + wildcardHost = "*." + hostname + dnsNames []string + ) + + if ip := net.ParseIP(hostname); ip != nil { + ips = append(ips, ip) + } else { + parts := strings.Split(hostname, ".") + l := len(parts) + + if leaf != nil { + dnsNames = append(dnsNames, leaf.DNSNames...) + } + + if l > 2 { + topIndex := l - 2 + topHost = strings.Join(parts[topIndex:], ".") + + for i := topIndex; i > 0; i-- { + wildcardHost = "*." + strings.Join(parts[i:], ".") + + if i == topIndex && (len(dnsNames) == 0 || dnsNames[0] != topHost) { + dnsNames = append(dnsNames, topHost, wildcardHost) + } else if !hasDnsNames(dnsNames, wildcardHost) { + dnsNames = append(dnsNames, wildcardHost) + } + } + } else { + dnsNames = append(dnsNames, topHost, wildcardHost) + } + + key = "+." + topHost + } + serial := atomic.AddInt64(¤tSerialNumber, 1) tmpl := &x509.Certificate{ SerialNumber: big.NewInt(serial), Subject: pkix.Name{ - CommonName: hostname, + CommonName: topHost, Organization: []string{c.organization}, }, SubjectKeyId: c.keyID, @@ -197,16 +215,10 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica BasicConstraintsValid: true, NotBefore: time.Now().Add(-c.validity), NotAfter: time.Now().Add(c.validity), + DNSNames: dnsNames, + IPAddresses: ips, } - if ip := net.ParseIP(hostname); ip != nil { - ips = append(ips, ip) - } else { - tmpl.DNSNames = []string{hostname} - } - - tmpl.IPAddresses = ips - raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey) if err != nil { return nil, err @@ -223,7 +235,7 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica Leaf: x509c, } - c.certsStorage.Set(hostname, tlsCertificate) + c.certsStorage.Set(key, tlsCertificate) return tlsCertificate, nil } @@ -280,3 +292,12 @@ func GenerateAndSave(caPath string, caKeyPath string) error { return nil } + +func hasDnsNames(dnsNames []string, hostname string) bool { + for _, name := range dnsNames { + if name == hostname { + return true + } + } + return false +} diff --git a/common/cert/cert_test.go b/common/cert/cert_test.go index 42265613..c237c588 100644 --- a/common/cert/cert_test.go +++ b/common/cert/cert_test.go @@ -18,7 +18,7 @@ func TestCert(t *testing.T) { assert.NotNil(t, ca) assert.NotNil(t, privateKey) - c, err := NewConfig(ca, privateKey, nil) + c, err := NewConfig(ca, privateKey) assert.Nil(t, err) c.SetValidity(20 * time.Hour) @@ -40,27 +40,55 @@ func TestCert(t *testing.T) { x509c := tlsCert.Leaf assert.Equal(t, "example.org", x509c.Subject.CommonName) assert.Nil(t, x509c.VerifyHostname("example.org")) + assert.Nil(t, x509c.VerifyHostname("abc.example.org")) assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization) assert.NotNil(t, x509c.SubjectKeyId) assert.True(t, x509c.BasicConstraintsValid) assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment) assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage) - assert.Equal(t, []string{"example.org"}, x509c.DNSNames) + assert.Equal(t, []string{"example.org", "*.example.org"}, x509c.DNSNames) assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour))) assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour))) // Check that certificate is cached - tlsCert2, err := c.GetOrCreateCert("example.org") + tlsCert2, err := c.GetOrCreateCert("abc.example.org") assert.Nil(t, err) assert.True(t, tlsCert == tlsCert2) - // Check the certificate for an IP - tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1:443") + // Check that certificate is new + _, _ = c.GetOrCreateCert("a.b.c.d.e.f.g.h.i.j.example.org") + tlsCert3, err := c.GetOrCreateCert("m.k.l.example.org") + x509c = tlsCert3.Leaf assert.Nil(t, err) + assert.False(t, tlsCert == tlsCert3) + assert.Equal(t, []string{"example.org", "*.example.org", "*.j.example.org", "*.i.j.example.org", "*.h.i.j.example.org", "*.g.h.i.j.example.org", "*.f.g.h.i.j.example.org", "*.e.f.g.h.i.j.example.org", "*.d.e.f.g.h.i.j.example.org", "*.c.d.e.f.g.h.i.j.example.org", "*.b.c.d.e.f.g.h.i.j.example.org", "*.l.example.org", "*.k.l.example.org"}, x509c.DNSNames) + + // Check that certificate is cached + tlsCert4, err := c.GetOrCreateCert("xyz.example.org") + x509c = tlsCert4.Leaf + assert.Nil(t, err) + assert.True(t, tlsCert3 == tlsCert4) + assert.Nil(t, x509c.VerifyHostname("example.org")) + assert.Nil(t, x509c.VerifyHostname("jkf.example.org")) + assert.Nil(t, x509c.VerifyHostname("n.j.example.org")) + assert.Nil(t, x509c.VerifyHostname("c.i.j.example.org")) + assert.Nil(t, x509c.VerifyHostname("m.l.example.org")) + assert.Error(t, x509c.VerifyHostname("m.l.jkf.example.org")) + + // Check the certificate for an IP + tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1") x509c = tlsCertForIP.Leaf + assert.Nil(t, err) assert.Equal(t, 1, len(x509c.IPAddresses)) assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0])) + + // Check that certificate is cached + tlsCertForIP2, err := c.GetOrCreateCert("192.168.0.1") + x509c = tlsCertForIP2.Leaf + assert.Nil(t, err) + assert.True(t, tlsCertForIP == tlsCertForIP2) + assert.Nil(t, x509c.VerifyHostname("192.168.0.1")) } func TestGenerateAndSave(t *testing.T) { diff --git a/common/cert/storage.go b/common/cert/storage.go index 61663e73..a55d065c 100644 --- a/common/cert/storage.go +++ b/common/cert/storage.go @@ -2,31 +2,31 @@ package cert import ( "crypto/tls" - "time" - "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/component/trie" ) -var TTL = time.Hour * 2 - -// AutoGCCertsStorage cache with the generated certificates, auto released after TTL -type AutoGCCertsStorage struct { - certsCache *cache.Cache[string, *tls.Certificate] +// DomainTrieCertsStorage cache wildcard certificates +type DomainTrieCertsStorage struct { + certsCache *trie.DomainTrie[*tls.Certificate] } // Get gets the certificate from the storage -func (c *AutoGCCertsStorage) Get(key string) (*tls.Certificate, bool) { - ca := c.certsCache.Get(key) - return ca, ca != nil +func (c *DomainTrieCertsStorage) Get(key string) (*tls.Certificate, bool) { + ca := c.certsCache.Search(key) + if ca == nil { + return nil, false + } + return ca.Data, true } // Set saves the certificate to the storage -func (c *AutoGCCertsStorage) Set(key string, cert *tls.Certificate) { - c.certsCache.Put(key, cert, TTL) +func (c *DomainTrieCertsStorage) Set(key string, cert *tls.Certificate) { + _ = c.certsCache.Insert(key, cert) } -func NewAutoGCCertsStorage() *AutoGCCertsStorage { - return &AutoGCCertsStorage{ - certsCache: cache.New[string, *tls.Certificate](TTL), +func NewDomainTrieCertsStorage() *DomainTrieCertsStorage { + return &DomainTrieCertsStorage{ + certsCache: trie.New[*tls.Certificate](), } } diff --git a/common/generics/list/list.go b/common/generics/list/list.go new file mode 100644 index 00000000..a06a7c61 --- /dev/null +++ b/common/generics/list/list.go @@ -0,0 +1,235 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package list implements a doubly linked list. +// +// To iterate over a list (where l is a *List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e.Value +// } +// +package list + +// Element is an element of a linked list. +type Element[T any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] + + // The list to which this element belongs. + list *List[T] + + // The value stored with this element. + Value T +} + +// Next returns the next list element or nil. +func (e *Element[T]) Next() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *Element[T]) Prev() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[T any] struct { + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *List[T]) Init() *List[T] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// New returns an initialized list. +func New[T any]() *List[T] { return new(List[T]).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *List[T]) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *List[T]) Front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[T]) Back() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{Value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *List[T]) Remove(e *Element[T]) T { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *List[T]) PushFront(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *List[T]) PushBack(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[T]) MoveToFront(e *Element[T]) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[T]) MoveToBack(e *Element[T]) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark) +} + +// PushBackList inserts a copy of another list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *List[T]) PushBackList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of another list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *List[T]) PushFrontList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/common/net/relay.go b/common/net/relay.go new file mode 100644 index 00000000..6035a412 --- /dev/null +++ b/common/net/relay.go @@ -0,0 +1,39 @@ +package net + +import ( + "io" + "net" + "time" + + "github.com/Dreamacro/clash/common/pool" +) + +// Relay copies between left and right bidirectionally. +func Relay(leftConn, rightConn net.Conn) { + ch := make(chan error) + + tcpKeepAlive(leftConn) + tcpKeepAlive(rightConn) + + go func() { + buf := pool.Get(pool.RelayBufferSize) + // Wrapping to avoid using *net.TCPConn.(ReadFrom) + // See also https://github.com/Dreamacro/clash/pull/1209 + _, err := io.CopyBuffer(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn}, buf) + _ = pool.Put(buf) + _ = leftConn.SetReadDeadline(time.Now()) + ch <- err + }() + + buf := pool.Get(pool.RelayBufferSize) + _, _ = io.CopyBuffer(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}, buf) + _ = pool.Put(buf) + _ = rightConn.SetReadDeadline(time.Now()) + <-ch +} + +func tcpKeepAlive(c net.Conn) { + if tcp, ok := c.(*net.TCPConn); ok { + _ = tcp.SetKeepAlive(true) + } +} diff --git a/common/observable/iterable.go b/common/observable/iterable.go index 2ac38b40..c78b49a3 100644 --- a/common/observable/iterable.go +++ b/common/observable/iterable.go @@ -1,3 +1,3 @@ package observable -type Iterable <-chan any +type Iterable[T any] <-chan T diff --git a/common/observable/observable.go b/common/observable/observable.go index 64bd0a0a..62b2e153 100644 --- a/common/observable/observable.go +++ b/common/observable/observable.go @@ -5,14 +5,14 @@ import ( "sync" ) -type Observable struct { - iterable Iterable - listener map[Subscription]*Subscriber +type Observable[T any] struct { + iterable Iterable[T] + listener map[Subscription[T]]*Subscriber[T] mux sync.Mutex done bool } -func (o *Observable) process() { +func (o *Observable[T]) process() { for item := range o.iterable { o.mux.Lock() for _, sub := range o.listener { @@ -23,7 +23,7 @@ func (o *Observable) process() { o.close() } -func (o *Observable) close() { +func (o *Observable[T]) close() { o.mux.Lock() defer o.mux.Unlock() @@ -33,18 +33,18 @@ func (o *Observable) close() { } } -func (o *Observable) Subscribe() (Subscription, error) { +func (o *Observable[T]) Subscribe() (Subscription[T], error) { o.mux.Lock() defer o.mux.Unlock() if o.done { - return nil, errors.New("Observable is closed") + return nil, errors.New("observable is closed") } - subscriber := newSubscriber() + subscriber := newSubscriber[T]() o.listener[subscriber.Out()] = subscriber return subscriber.Out(), nil } -func (o *Observable) UnSubscribe(sub Subscription) { +func (o *Observable[T]) UnSubscribe(sub Subscription[T]) { o.mux.Lock() defer o.mux.Unlock() subscriber, exist := o.listener[sub] @@ -55,10 +55,10 @@ func (o *Observable) UnSubscribe(sub Subscription) { subscriber.Close() } -func NewObservable(any Iterable) *Observable { - observable := &Observable{ - iterable: any, - listener: map[Subscription]*Subscriber{}, +func NewObservable[T any](iter Iterable[T]) *Observable[T] { + observable := &Observable[T]{ + iterable: iter, + listener: map[Subscription[T]]*Subscriber[T]{}, } go observable.process() return observable diff --git a/common/observable/observable_test.go b/common/observable/observable_test.go index da3e6d5c..5459e0e2 100644 --- a/common/observable/observable_test.go +++ b/common/observable/observable_test.go @@ -9,8 +9,8 @@ import ( "go.uber.org/atomic" ) -func iterator(item []any) chan any { - ch := make(chan any) +func iterator[T any](item []T) chan T { + ch := make(chan T) go func() { time.Sleep(100 * time.Millisecond) for _, elm := range item { @@ -22,8 +22,8 @@ func iterator(item []any) chan any { } func TestObservable(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) + iter := iterator[int]([]int{1, 2, 3, 4, 5}) + src := NewObservable[int](iter) data, err := src.Subscribe() assert.Nil(t, err) count := 0 @@ -34,15 +34,15 @@ func TestObservable(t *testing.T) { } func TestObservable_MultiSubscribe(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) + iter := iterator[int]([]int{1, 2, 3, 4, 5}) + src := NewObservable[int](iter) ch1, _ := src.Subscribe() ch2, _ := src.Subscribe() count := atomic.NewInt32(0) var wg sync.WaitGroup wg.Add(2) - waitCh := func(ch <-chan any) { + waitCh := func(ch <-chan int) { for range ch { count.Inc() } @@ -55,8 +55,8 @@ func TestObservable_MultiSubscribe(t *testing.T) { } func TestObservable_UnSubscribe(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) + iter := iterator[int]([]int{1, 2, 3, 4, 5}) + src := NewObservable[int](iter) data, err := src.Subscribe() assert.Nil(t, err) src.UnSubscribe(data) @@ -65,8 +65,8 @@ func TestObservable_UnSubscribe(t *testing.T) { } func TestObservable_SubscribeClosedSource(t *testing.T) { - iter := iterator([]any{1}) - src := NewObservable(iter) + iter := iterator[int]([]int{1}) + src := NewObservable[int](iter) data, _ := src.Subscribe() <-data @@ -75,18 +75,18 @@ func TestObservable_SubscribeClosedSource(t *testing.T) { } func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { - sub := Subscription(make(chan any)) - iter := iterator([]any{1}) - src := NewObservable(iter) + sub := Subscription[int](make(chan int)) + iter := iterator[int]([]int{1}) + src := NewObservable[int](iter) src.UnSubscribe(sub) } func TestObservable_SubscribeGoroutineLeak(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) + iter := iterator[int]([]int{1, 2, 3, 4, 5}) + src := NewObservable[int](iter) max := 100 - var list []Subscription + var list []Subscription[int] for i := 0; i < max; i++ { ch, _ := src.Subscribe() list = append(list, ch) @@ -94,7 +94,7 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) { var wg sync.WaitGroup wg.Add(max) - waitCh := func(ch <-chan any) { + waitCh := func(ch <-chan int) { for range ch { } wg.Done() @@ -115,11 +115,11 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) { } func Benchmark_Observable_1000(b *testing.B) { - ch := make(chan any) - o := NewObservable(ch) + ch := make(chan int) + o := NewObservable[int](ch) num := 1000 - subs := []Subscription{} + subs := []Subscription[int]{} for i := 0; i < num; i++ { sub, _ := o.Subscribe() subs = append(subs, sub) @@ -130,7 +130,7 @@ func Benchmark_Observable_1000(b *testing.B) { b.ResetTimer() for _, sub := range subs { - go func(s Subscription) { + go func(s Subscription[int]) { for range s { } wg.Done() diff --git a/common/observable/subscriber.go b/common/observable/subscriber.go index 0d8559bc..b7df4cae 100644 --- a/common/observable/subscriber.go +++ b/common/observable/subscriber.go @@ -4,30 +4,30 @@ import ( "sync" ) -type Subscription <-chan any +type Subscription[T any] <-chan T -type Subscriber struct { - buffer chan any +type Subscriber[T any] struct { + buffer chan T once sync.Once } -func (s *Subscriber) Emit(item any) { +func (s *Subscriber[T]) Emit(item T) { s.buffer <- item } -func (s *Subscriber) Out() Subscription { +func (s *Subscriber[T]) Out() Subscription[T] { return s.buffer } -func (s *Subscriber) Close() { +func (s *Subscriber[T]) Close() { s.once.Do(func() { close(s.buffer) }) } -func newSubscriber() *Subscriber { - sub := &Subscriber{ - buffer: make(chan any, 200), +func newSubscriber[T any]() *Subscriber[T] { + sub := &Subscriber[T]{ + buffer: make(chan T, 200), } return sub } diff --git a/common/picker/picker.go b/common/picker/picker.go index e701268a..97004460 100644 --- a/common/picker/picker.go +++ b/common/picker/picker.go @@ -9,7 +9,7 @@ import ( // Picker provides synchronization, and Context cancelation // for groups of goroutines working on subtasks of a common task. // Inspired by errGroup -type Picker struct { +type Picker[T any] struct { ctx context.Context cancel func() @@ -17,12 +17,12 @@ type Picker struct { once sync.Once errOnce sync.Once - result any + result T err error } -func newPicker(ctx context.Context, cancel func()) *Picker { - return &Picker{ +func newPicker[T any](ctx context.Context, cancel func()) *Picker[T] { + return &Picker[T]{ ctx: ctx, cancel: cancel, } @@ -30,20 +30,20 @@ func newPicker(ctx context.Context, cancel func()) *Picker { // WithContext returns a new Picker and an associated Context derived from ctx. // and cancel when first element return. -func WithContext(ctx context.Context) (*Picker, context.Context) { +func WithContext[T any](ctx context.Context) (*Picker[T], context.Context) { ctx, cancel := context.WithCancel(ctx) - return newPicker(ctx, cancel), ctx + return newPicker[T](ctx, cancel), ctx } // WithTimeout returns a new Picker and an associated Context derived from ctx with timeout. -func WithTimeout(ctx context.Context, timeout time.Duration) (*Picker, context.Context) { +func WithTimeout[T any](ctx context.Context, timeout time.Duration) (*Picker[T], context.Context) { ctx, cancel := context.WithTimeout(ctx, timeout) - return newPicker(ctx, cancel), ctx + return newPicker[T](ctx, cancel), ctx } // Wait blocks until all function calls from the Go method have returned, // then returns the first nil error result (if any) from them. -func (p *Picker) Wait() any { +func (p *Picker[T]) Wait() T { p.wg.Wait() if p.cancel != nil { p.cancel() @@ -52,13 +52,13 @@ func (p *Picker) Wait() any { } // Error return the first error (if all success return nil) -func (p *Picker) Error() error { +func (p *Picker[T]) Error() error { return p.err } // Go calls the given function in a new goroutine. // The first call to return a nil error cancels the group; its result will be returned by Wait. -func (p *Picker) Go(f func() (any, error)) { +func (p *Picker[T]) Go(f func() (T, error)) { p.wg.Add(1) go func() { diff --git a/common/picker/picker_test.go b/common/picker/picker_test.go index ca10499d..17b823cb 100644 --- a/common/picker/picker_test.go +++ b/common/picker/picker_test.go @@ -8,33 +8,38 @@ import ( "github.com/stretchr/testify/assert" ) -func sleepAndSend(ctx context.Context, delay int, input any) func() (any, error) { - return func() (any, error) { +func sleepAndSend[T any](ctx context.Context, delay int, input T) func() (T, error) { + return func() (T, error) { timer := time.NewTimer(time.Millisecond * time.Duration(delay)) select { case <-timer.C: return input, nil case <-ctx.Done(): - return nil, ctx.Err() + return getZero[T](), ctx.Err() } } } func TestPicker_Basic(t *testing.T) { - picker, ctx := WithContext(context.Background()) + picker, ctx := WithContext[int](context.Background()) picker.Go(sleepAndSend(ctx, 30, 2)) picker.Go(sleepAndSend(ctx, 20, 1)) number := picker.Wait() assert.NotNil(t, number) - assert.Equal(t, number.(int), 1) + assert.Equal(t, number, 1) } func TestPicker_Timeout(t *testing.T) { - picker, ctx := WithTimeout(context.Background(), time.Millisecond*5) + picker, ctx := WithTimeout[int](context.Background(), time.Millisecond*5) picker.Go(sleepAndSend(ctx, 20, 1)) number := picker.Wait() - assert.Nil(t, number) + assert.Equal(t, number, getZero[int]()) assert.NotNil(t, picker.Error()) } + +func getZero[T any]() T { + var result T + return result +} diff --git a/common/singledo/singledo.go b/common/singledo/singledo.go index f6ff35a9..c741fc82 100644 --- a/common/singledo/singledo.go +++ b/common/singledo/singledo.go @@ -5,28 +5,28 @@ import ( "time" ) -type call struct { +type call[T any] struct { wg sync.WaitGroup - val any + val T err error } -type Single struct { +type Single[T any] struct { mux sync.Mutex last time.Time wait time.Duration - call *call - result *Result + call *call[T] + result *Result[T] } -type Result struct { - Val any +type Result[T any] struct { + Val T Err error } // Do single.Do likes sync.singleFlight //lint:ignore ST1008 it likes sync.singleFlight -func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) { +func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) { s.mux.Lock() now := time.Now() if now.Before(s.last.Add(s.wait)) { @@ -34,31 +34,31 @@ func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) { return s.result.Val, s.result.Err, true } - if call := s.call; call != nil { + if callM := s.call; callM != nil { s.mux.Unlock() - call.wg.Wait() - return call.val, call.err, true + callM.wg.Wait() + return callM.val, callM.err, true } - call := &call{} - call.wg.Add(1) - s.call = call + callM := &call[T]{} + callM.wg.Add(1) + s.call = callM s.mux.Unlock() - call.val, call.err = fn() - call.wg.Done() + callM.val, callM.err = fn() + callM.wg.Done() s.mux.Lock() s.call = nil - s.result = &Result{call.val, call.err} + s.result = &Result[T]{callM.val, callM.err} s.last = now s.mux.Unlock() - return call.val, call.err, false + return callM.val, callM.err, false } -func (s *Single) Reset() { +func (s *Single[T]) Reset() { s.last = time.Time{} } -func NewSingle(wait time.Duration) *Single { - return &Single{wait: wait} +func NewSingle[T any](wait time.Duration) *Single[T] { + return &Single[T]{wait: wait} } diff --git a/common/singledo/singledo_test.go b/common/singledo/singledo_test.go index 71b6ac9f..26e3d37d 100644 --- a/common/singledo/singledo_test.go +++ b/common/singledo/singledo_test.go @@ -10,13 +10,13 @@ import ( ) func TestBasic(t *testing.T) { - single := NewSingle(time.Millisecond * 30) + single := NewSingle[int](time.Millisecond * 30) foo := 0 shardCount := atomic.NewInt32(0) - call := func() (any, error) { + call := func() (int, error) { foo++ time.Sleep(time.Millisecond * 5) - return nil, nil + return 0, nil } var wg sync.WaitGroup @@ -38,32 +38,32 @@ func TestBasic(t *testing.T) { } func TestTimer(t *testing.T) { - single := NewSingle(time.Millisecond * 30) + single := NewSingle[int](time.Millisecond * 30) foo := 0 - call := func() (any, error) { + callM := func() (int, error) { foo++ - return nil, nil + return 0, nil } - single.Do(call) + _, _, _ = single.Do(callM) time.Sleep(10 * time.Millisecond) - _, _, shard := single.Do(call) + _, _, shard := single.Do(callM) assert.Equal(t, 1, foo) assert.True(t, shard) } func TestReset(t *testing.T) { - single := NewSingle(time.Millisecond * 30) + single := NewSingle[int](time.Millisecond * 30) foo := 0 - call := func() (any, error) { + callM := func() (int, error) { foo++ - return nil, nil + return 0, nil } - single.Do(call) + _, _, _ = single.Do(callM) single.Reset() - single.Do(call) + _, _, _ = single.Do(callM) assert.Equal(t, 2, foo) } diff --git a/component/geodata/strmatcher/ac_automaton_matcher.go b/component/geodata/strmatcher/ac_automaton_matcher.go index ef0bc5d9..d134c68a 100644 --- a/component/geodata/strmatcher/ac_automaton_matcher.go +++ b/component/geodata/strmatcher/ac_automaton_matcher.go @@ -1,7 +1,7 @@ package strmatcher import ( - "container/list" + "github.com/Dreamacro/clash/common/generics/list" ) const validCharCount = 53 @@ -190,7 +190,7 @@ func (ac *ACAutomaton) Add(domain string, t Type) { } func (ac *ACAutomaton) Build() { - queue := list.New() + queue := list.New[Edge]() for i := 0; i < validCharCount; i++ { if ac.trie[0][i].nextNode != 0 { queue.PushBack(ac.trie[0][i]) @@ -201,7 +201,7 @@ func (ac *ACAutomaton) Build() { if front == nil { break } else { - node := front.Value.(Edge).nextNode + node := front.Value.nextNode queue.Remove(front) for i := 0; i < validCharCount; i++ { if ac.trie[node][i].nextNode != 0 { diff --git a/component/iface/iface.go b/component/iface/iface.go index 637d4876..11c754f8 100644 --- a/component/iface/iface.go +++ b/component/iface/iface.go @@ -21,10 +21,10 @@ var ( ErrAddrNotFound = errors.New("addr not found") ) -var interfaces = singledo.NewSingle(time.Second * 20) +var interfaces = singledo.NewSingle[map[string]*Interface](time.Second * 20) func ResolveInterface(name string) (*Interface, error) { - value, err, _ := interfaces.Do(func() (any, error) { + value, err, _ := interfaces.Do(func() (map[string]*Interface, error) { ifaces, err := net.Interfaces() if err != nil { return nil, err @@ -66,7 +66,7 @@ func ResolveInterface(name string) (*Interface, error) { return nil, err } - ifaces := value.(map[string]*Interface) + ifaces := value iface, ok := ifaces[name] if !ok { return nil, ErrIfaceNotFound diff --git a/component/pool/pool.go b/component/pool/pool.go index ef117539..f8173761 100644 --- a/component/pool/pool.go +++ b/component/pool/pool.go @@ -6,55 +6,55 @@ import ( "time" ) -type Factory = func(context.Context) (any, error) +type Factory[T any] func(context.Context) (T, error) -type entry struct { - elm any +type entry[T any] struct { + elm T time time.Time } -type Option func(*pool) +type Option[T any] func(*pool[T]) // WithEvict set the evict callback -func WithEvict(cb func(any)) Option { - return func(p *pool) { +func WithEvict[T any](cb func(T)) Option[T] { + return func(p *pool[T]) { p.evict = cb } } // WithAge defined element max age (millisecond) -func WithAge(maxAge int64) Option { - return func(p *pool) { +func WithAge[T any](maxAge int64) Option[T] { + return func(p *pool[T]) { p.maxAge = maxAge } } // WithSize defined max size of Pool -func WithSize(maxSize int) Option { - return func(p *pool) { - p.ch = make(chan any, maxSize) +func WithSize[T any](maxSize int) Option[T] { + return func(p *pool[T]) { + p.ch = make(chan *entry[T], maxSize) } } // Pool is for GC, see New for detail -type Pool struct { - *pool +type Pool[T any] struct { + *pool[T] } -type pool struct { - ch chan any - factory Factory - evict func(any) +type pool[T any] struct { + ch chan *entry[T] + factory Factory[T] + evict func(T) maxAge int64 } -func (p *pool) GetContext(ctx context.Context) (any, error) { +func (p *pool[T]) GetContext(ctx context.Context) (T, error) { now := time.Now() for { select { case item := <-p.ch: - elm := item.(*entry) - if p.maxAge != 0 && now.Sub(item.(*entry).time).Milliseconds() > p.maxAge { + elm := item + if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge { if p.evict != nil { p.evict(elm.elm) } @@ -68,12 +68,12 @@ func (p *pool) GetContext(ctx context.Context) (any, error) { } } -func (p *pool) Get() (any, error) { +func (p *pool[T]) Get() (T, error) { return p.GetContext(context.Background()) } -func (p *pool) Put(item any) { - e := &entry{ +func (p *pool[T]) Put(item T) { + e := &entry[T]{ elm: item, time: time.Now(), } @@ -90,17 +90,17 @@ func (p *pool) Put(item any) { } } -func recycle(p *Pool) { +func recycle[T any](p *Pool[T]) { for item := range p.pool.ch { if p.pool.evict != nil { - p.pool.evict(item.(*entry).elm) + p.pool.evict(item.elm) } } } -func New(factory Factory, options ...Option) *Pool { - p := &pool{ - ch: make(chan any, 10), +func New[T any](factory Factory[T], options ...Option[T]) *Pool[T] { + p := &pool[T]{ + ch: make(chan *entry[T], 10), factory: factory, } @@ -108,7 +108,7 @@ func New(factory Factory, options ...Option) *Pool { option(p) } - P := &Pool{p} - runtime.SetFinalizer(P, recycle) + P := &Pool[T]{p} + runtime.SetFinalizer(P, recycle[T]) return P } diff --git a/component/pool/pool_test.go b/component/pool/pool_test.go index 5492f4c8..752aaace 100644 --- a/component/pool/pool_test.go +++ b/component/pool/pool_test.go @@ -8,9 +8,9 @@ import ( "github.com/stretchr/testify/assert" ) -func lg() Factory { +func lg() Factory[int] { initial := -1 - return func(context.Context) (any, error) { + return func(context.Context) (int, error) { initial++ return initial, nil } @@ -18,23 +18,23 @@ func lg() Factory { func TestPool_Basic(t *testing.T) { g := lg() - pool := New(g) + pool := New[int](g) elm, _ := pool.Get() - assert.Equal(t, 0, elm.(int)) + assert.Equal(t, 0, elm) pool.Put(elm) elm, _ = pool.Get() - assert.Equal(t, 0, elm.(int)) + assert.Equal(t, 0, elm) elm, _ = pool.Get() - assert.Equal(t, 1, elm.(int)) + assert.Equal(t, 1, elm) } func TestPool_MaxSize(t *testing.T) { g := lg() size := 5 - pool := New(g, WithSize(size)) + pool := New[int](g, WithSize[int](size)) - var items []any + var items []int for i := 0; i < size; i++ { item, _ := pool.Get() @@ -42,7 +42,7 @@ func TestPool_MaxSize(t *testing.T) { } extra, _ := pool.Get() - assert.Equal(t, size, extra.(int)) + assert.Equal(t, size, extra) for _, item := range items { pool.Put(item) @@ -52,22 +52,22 @@ func TestPool_MaxSize(t *testing.T) { for _, item := range items { elm, _ := pool.Get() - assert.Equal(t, item.(int), elm.(int)) + assert.Equal(t, item, elm) } } func TestPool_MaxAge(t *testing.T) { g := lg() - pool := New(g, WithAge(20)) + pool := New[int](g, WithAge[int](20)) elm, _ := pool.Get() pool.Put(elm) elm, _ = pool.Get() - assert.Equal(t, 0, elm.(int)) + assert.Equal(t, 0, elm) pool.Put(elm) time.Sleep(time.Millisecond * 22) elm, _ = pool.Get() - assert.Equal(t, 1, elm.(int)) + assert.Equal(t, 1, elm) } diff --git a/dns/resolver.go b/dns/resolver.go index accd9a8c..27d23d5c 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -152,10 +152,10 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M } func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { - fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout) + fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout) for _, client := range clients { r := client - fast.Go(func() (any, error) { + fast.Go(func() (*D.Msg, error) { m, err := r.ExchangeContext(ctx, m) if err != nil { return nil, err @@ -175,7 +175,7 @@ func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D. return nil, err } - msg = elm.(*D.Msg) + msg = elm return } diff --git a/go.mod b/go.mod index 934ee95d..6ae679db 100644 --- a/go.mod +++ b/go.mod @@ -19,16 +19,16 @@ require ( go.uber.org/atomic v1.9.0 go.uber.org/automaxprocs v1.5.1 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 - golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 + golang.org/x/net v0.0.0-20220421235706-1d1ef9303861 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20220412211240-33da011f77ad + golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab golang.org/x/time v0.0.0-20220411224347-583f2d630306 golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6 golang.zx2c4.com/wireguard/windows v0.5.4-0.20220317000008-6432784c2469 google.golang.org/protobuf v1.28.0 gopkg.in/yaml.v2 v2.4.0 - gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b + gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 ) require ( diff --git a/go.sum b/go.sum index bcfdfad0..74dcb991 100644 --- a/go.sum +++ b/go.sum @@ -97,8 +97,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 h1:6mzvA99KwZxbOrxww4EvWVQUnN1+xEu9tafK5ZxkYeA= -golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220421235706-1d1ef9303861 h1:yssD99+7tqHWO5Gwh81phT+67hg+KttniBr6UnEXOY8= +golang.org/x/net v0.0.0-20220421235706-1d1ef9303861/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -122,8 +122,8 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc= +golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -159,5 +159,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b h1:zBJp2eKSoNIV6+9LO3bRhlnuK280Oyrwc6OeFIN6VzU= -gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= +gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 h1:CSkd548jw5hmVwdJ+JuUhMtRV56oQBER7sbkIOePP2Y= +gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 93babbe1..f302ce61 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -2,11 +2,9 @@ package executor import ( "fmt" - "net" "net/netip" "os" "runtime" - "strconv" "sync" "github.com/Dreamacro/clash/adapter" @@ -317,13 +315,7 @@ func updateIPTables(cfg *config.Config) { return } - _, dnsPortStr, err := net.SplitHostPort(dnsCfg.Listen) - if err != nil { - err = fmt.Errorf("DNS server must be enable") - return - } - - dnsPort, err := strconv.ParseUint(dnsPortStr, 10, 16) + dnsPort, err := netip.ParseAddrPort(dnsCfg.Listen) if err != nil { err = fmt.Errorf("DNS server must be enable") return @@ -337,7 +329,7 @@ func updateIPTables(cfg *config.Config) { dialer.DefaultRoutingMark.Store(2158) } - err = tproxy.SetTProxyIPTables(inboundInterface, uint16(tProxyPort), uint16(dnsPort)) + err = tproxy.SetTProxyIPTables(inboundInterface, uint16(tProxyPort), dnsPort.Port()) if err != nil { return } diff --git a/hub/route/server.go b/hub/route/server.go index 0b8cc84e..99575415 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -51,14 +51,14 @@ func Start(addr string, secret string) { r := chi.NewRouter() - cors := cors.New(cors.Options{ + corsM := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"}, AllowedHeaders: []string{"Content-Type", "Authorization"}, MaxAge: 300, }) - r.Use(cors.Handler) + r.Use(corsM.Handler) r.Group(func(r chi.Router) { r.Use(authentication) @@ -209,24 +209,35 @@ func getLogs(w http.ResponseWriter, r *http.Request) { render.Status(r, http.StatusOK) } + ch := make(chan log.Event, 1024) sub := log.Subscribe() defer log.UnSubscribe(sub) buf := &bytes.Buffer{} - var err error - for elm := range sub { - buf.Reset() - log := elm.(*log.Event) - if log.LogLevel < level { + + go func() { + for elm := range sub { + select { + case ch <- elm: + default: + } + } + close(ch) + }() + + for logM := range ch { + if logM.LogLevel < level { continue } + buf.Reset() if err := json.NewEncoder(buf).Encode(Log{ - Type: log.Type(), - Payload: log.Payload, + Type: logM.Type(), + Payload: logM.Payload, }); err != nil { break } + var err error if wsConn == nil { _, err = w.Write(buf.Bytes()) w.(http.Flusher).Flush() diff --git a/listener/http/proxy.go b/listener/http/proxy.go index bd39b8b4..0ec43dc7 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -19,12 +19,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, client := newClient(c.RemoteAddr(), in) defer client.CloseIdleConnections() - var conn *N.BufferedConn - if bufConn, ok := c.(*N.BufferedConn); ok { - conn = bufConn - } else { - conn = N.NewBufferedConn(c) - } + conn := N.NewBufferedConn(c) keepAlive := true trusted := cache == nil // disable authenticate if cache is nil @@ -66,15 +61,23 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" - RemoveHopByHopHeaders(request.Header) - RemoveExtraHTTPHostPort(request) + if isUpgradeRequest(request) { + if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil { + return // hijack connection + } + } - if request.URL.Scheme == "" || request.URL.Host == "" { - resp = responseWith(request, http.StatusBadRequest) - } else { - resp, err = client.Do(request) - if err != nil { - resp = responseWith(request, http.StatusBadGateway) + if resp == nil { + RemoveHopByHopHeaders(request.Header) + RemoveExtraHTTPHostPort(request) + + if request.URL.Scheme == "" || request.URL.Host == "" { + resp = responseWith(request, http.StatusBadRequest) + } else { + resp, err = client.Do(request) + if err != nil { + resp = responseWith(request, http.StatusBadGateway) + } } } @@ -95,7 +98,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, } } - conn.Close() + _ = conn.Close() } func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go new file mode 100644 index 00000000..7e53eecf --- /dev/null +++ b/listener/http/upgrade.go @@ -0,0 +1,96 @@ +package http + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "strings" + "time" + + "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" +) + +func isUpgradeRequest(req *http.Request) bool { + return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") +} + +func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { + removeProxyHeaders(request.Header) + RemoveExtraHTTPHostPort(request) + + address := request.Host + if _, _, err := net.SplitHostPort(address); err != nil { + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) + } + + dstAddr := socks5.ParseAddr(address) + if dstAddr == nil { + return + } + + left, right := net.Pipe() + + in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right) + + var remoteServer *N.BufferedConn + if request.TLS != nil { + tlsConn := tls.Client(left, &tls.Config{ + ServerName: request.URL.Hostname(), + }) + + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if tlsConn.HandshakeContext(ctx) != nil { + _ = localConn.Close() + _ = left.Close() + return + } + + remoteServer = N.NewBufferedConn(tlsConn) + } else { + remoteServer = N.NewBufferedConn(left) + } + + defer func() { + _ = remoteServer.Close() + }() + + err := request.Write(remoteServer) + if err != nil { + _ = localConn.Close() + return + } + + resp, err = http.ReadResponse(remoteServer.Reader(), request) + if err != nil { + _ = localConn.Close() + return + } + + if resp.StatusCode == http.StatusSwitchingProtocols { + removeProxyHeaders(resp.Header) + + err = localConn.SetReadDeadline(time.Time{}) // set to not time out + if err != nil { + return + } + + err = resp.Write(localConn) + if err != nil { + return + } + + N.Relay(remoteServer, localConn) // blocking here + _ = localConn.Close() + resp = nil + } + return +} diff --git a/listener/http/utils.go b/listener/http/utils.go index 94308f19..e9994acc 100644 --- a/listener/http/utils.go +++ b/listener/http/utils.go @@ -8,15 +8,21 @@ import ( "strings" ) +// removeProxyHeaders remove Proxy-* headers +func removeProxyHeaders(header http.Header) { + header.Del("Proxy-Connection") + header.Del("Proxy-Authenticate") + header.Del("Proxy-Authorization") +} + // RemoveHopByHopHeaders remove hop-by-hop header func RemoveHopByHopHeaders(header http.Header) { // Strip hop-by-hop header based on RFC: // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 // https://www.mnot.net/blog/2011/07/11/what_proxies_must_do - header.Del("Proxy-Connection") - header.Del("Proxy-Authenticate") - header.Del("Proxy-Authorization") + removeProxyHeaders(header) + header.Del("TE") header.Del("Trailers") header.Del("Transfer-Encoding") diff --git a/listener/listener.go b/listener/listener.go index 66bf044d..34c43f64 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -397,13 +397,12 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) { certOption, err = cert.NewConfig( x509c, privateKey, - cert.NewAutoGCCertsStorage(), ) if err != nil { return } - certOption.SetValidity(time.Hour * 24 * 90) + certOption.SetValidity(time.Hour * 24 * 365 * 2) // 2 years certOption.SetOrganization("Clash ManInTheMiddle Proxy Services") opt := &mitm.Option{ diff --git a/listener/mitm/client.go b/listener/mitm/client.go index b20d8586..a01c65d8 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -18,9 +18,11 @@ func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http Transport: &http.Transport{ // excepted HTTP/2 TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), - // from http.DefaultTransport - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, + // only needed 1 connection + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: 60 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{ diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 98323e9a..09aecbb7 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -44,13 +44,13 @@ startOver: readLoop: for { // use SetReadDeadline instead of Proxy-Connection keep-alive - if err := conn.SetReadDeadline(time.Now().Add(95 * time.Second)); err != nil { - break readLoop + if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil { + break } request, err := H.ReadRequest(conn.Reader()) if err != nil { - break readLoop + break } var response *http.Response @@ -71,7 +71,7 @@ readLoop: // Manual writing to support CONNECT for http 1.0 (workaround for uplay client) if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { handleError(opt, session, err) - break readLoop // close connection + break // close connection } if strings.HasSuffix(session.request.URL.Host, ":80") { @@ -81,18 +81,18 @@ readLoop: b, err := conn.Peek(1) if err != nil { handleError(opt, session, err) - break readLoop // close connection + break // close connection } // TLS handshake. if b[0] == 0x16 { - tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) + tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Hostname())) // Handshake with the local client if err = tlsConn.Handshake(); err != nil { session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) _ = writeResponse(session, false) - break readLoop // close connection + break // close connection } c = tlsConn @@ -105,20 +105,27 @@ readLoop: prepareRequest(c, session.request) - H.RemoveHopByHopHeaders(session.request.Header) - H.RemoveExtraHTTPHostPort(session.request) - // hijack api if session.request.URL.Hostname() == opt.ApiHost { if err = handleApiRequest(session, opt); err != nil { handleError(opt, session, err) - break readLoop } - return + break } + // forward websocket + if isWebsocketRequest(request) { + session.request.RequestURI = "" + if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil { + return // hijack connection + } + } + + H.RemoveHopByHopHeaders(session.request.Header) + H.RemoveExtraHTTPHostPort(session.request) + // hijack custom request and write back custom response if necessary - if opt.Handler != nil { + if opt.Handler != nil && session.response == nil { newReq, newRes := opt.Handler.HandleRequest(session) if newReq != nil { session.request = newReq @@ -128,28 +135,30 @@ readLoop: if err = writeResponse(session, false); err != nil { handleError(opt, session, err) - break readLoop + break } - return + continue } } - session.request.RequestURI = "" + if session.response == nil { + session.request.RequestURI = "" - if session.request.URL.Host == "" { - session.response = session.NewErrorResponse(ErrInvalidURL) - } else { - client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) + if session.request.URL.Host == "" { + session.response = session.NewErrorResponse(ErrInvalidURL) + } else { + client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) - // send the request to remote server - session.response, err = client.Do(session.request) + // send the request to remote server + session.response, err = client.Do(session.request) - if err != nil { - handleError(opt, session, err) - session.response = session.NewErrorResponse(err) - if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { - _ = writeResponse(session, false) - break readLoop + if err != nil { + handleError(opt, session, err) + session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err)) + if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { + _ = writeResponse(session, false) + break + } } } } @@ -157,7 +166,7 @@ readLoop: if err = writeResponseWithHandler(session, opt); err != nil { handleError(opt, session, err) - break readLoop // close connection + break // close connection } } @@ -167,13 +176,7 @@ readLoop: func writeResponseWithHandler(session *Session, opt *Option) error { if opt.Handler != nil { res := opt.Handler.HandleResponse(session) - if res != nil { - body := res.Body - defer func(body io.ReadCloser) { - _ = body.Close() - }(body) - session.response = res } } @@ -186,7 +189,7 @@ func writeResponse(session *Session, keepAlive bool) error { if keepAlive { session.response.Header.Set("Connection", "keep-alive") - session.response.Header.Set("Keep-Alive", "timeout=90") + session.response.Header.Set("Keep-Alive", "timeout=60") } return session.writeResponse() @@ -201,10 +204,6 @@ func handleApiRequest(session *Session, opt *Option) error { session.response = session.NewResponse(http.StatusOK, bytes.NewReader(b)) - defer func(body io.ReadCloser) { - _ = body.Close() - }(session.response.Body) - session.response.Close = true session.response.Header.Set("Content-Type", "application/x-x509-ca-cert") session.response.ContentLength = int64(len(b)) @@ -230,11 +229,6 @@ func handleApiRequest(session *Session, opt *Option) error { b = fmt.Sprintf(b, session.request.URL.Path) session.response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b))) - - defer func(body io.ReadCloser) { - _ = body.Close() - }(session.response.Body) - session.response.Close = true session.response.Header.Set("Content-Type", "text/html;charset=utf-8") session.response.ContentLength = int64(len(b)) @@ -243,6 +237,12 @@ func handleApiRequest(session *Session, opt *Option) error { } func handleError(opt *Option, session *Session, err error) { + if session.response != nil { + defer func() { + _, _ = io.Copy(io.Discard, session.response.Body) + _ = session.response.Body.Close() + }() + } if opt.Handler != nil { opt.Handler.HandleError(session, err) } diff --git a/listener/mitm/session.go b/listener/mitm/session.go index c2622a69..99979a98 100644 --- a/listener/mitm/session.go +++ b/listener/mitm/session.go @@ -43,6 +43,9 @@ func (s *Session) writeResponse() error { if s.response == nil { return ErrInvalidResponse } + defer func(resp *http.Response) { + _ = resp.Body.Close() + }(s.response) return s.response.Write(s.conn) } diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index a84c75cf..d7c10a2a 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -20,6 +20,10 @@ var ( ErrInvalidURL = errors.New("invalid URL") ) +func isWebsocketRequest(req *http.Request) bool { + return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" +} + func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { if body == nil { body = &bytes.Buffer{} diff --git a/listener/tproxy/tproxy_linux_iptables.go b/listener/tproxy/tproxy_iptables.go similarity index 100% rename from listener/tproxy/tproxy_linux_iptables.go rename to listener/tproxy/tproxy_iptables.go diff --git a/listener/tun/ipstack/system/mars/mars.go b/listener/tun/ipstack/system/mars/mars.go index e150437e..a553c2d6 100644 --- a/listener/tun/ipstack/system/mars/mars.go +++ b/listener/tun/ipstack/system/mars/mars.go @@ -27,10 +27,8 @@ func StartListener(device io.ReadWriteCloser, gateway, portal, broadcast netip.A } func (t *StackListener) Close() error { - _ = t.tcp.Close() _ = t.udp.Close() - - return t.device.Close() + return t.tcp.Close() } func (t *StackListener) TCP() *nat.TCP { diff --git a/listener/tun/ipstack/system/mars/nat/table.go b/listener/tun/ipstack/system/mars/nat/table.go index e0d86ccc..38b7d6c6 100644 --- a/listener/tun/ipstack/system/mars/nat/table.go +++ b/listener/tun/ipstack/system/mars/nat/table.go @@ -1,13 +1,14 @@ package nat import ( - "container/list" "net/netip" + + "github.com/Dreamacro/clash/common/generics/list" ) const ( portBegin = 30000 - portLength = 4096 + portLength = 10240 ) var zeroTuple = tuple{} @@ -23,9 +24,9 @@ type binding struct { } type table struct { - tuples map[tuple]*list.Element - ports [portLength]*list.Element - available *list.List + tuples map[tuple]*list.Element[*binding] + ports [portLength]*list.Element[*binding] + available *list.List[*binding] } func (t *table) tupleOf(port uint16) tuple { @@ -38,7 +39,7 @@ func (t *table) tupleOf(port uint16) tuple { t.available.MoveToFront(elm) - return elm.Value.(*binding).tuple + return elm.Value.tuple } func (t *table) portOf(tuple tuple) uint16 { @@ -49,12 +50,12 @@ func (t *table) portOf(tuple tuple) uint16 { t.available.MoveToFront(elm) - return portBegin + elm.Value.(*binding).offset + return portBegin + elm.Value.offset } func (t *table) newConn(tuple tuple) uint16 { elm := t.available.Back() - b := elm.Value.(*binding) + b := elm.Value delete(t.tuples, b.tuple) t.tuples[tuple] = elm @@ -67,9 +68,9 @@ func (t *table) newConn(tuple tuple) uint16 { func newTable() *table { result := &table{ - tuples: make(map[tuple]*list.Element, portLength), - ports: [portLength]*list.Element{}, - available: list.New(), + tuples: make(map[tuple]*list.Element[*binding], portLength), + ports: [portLength]*list.Element[*binding]{}, + available: list.New[*binding](), } for idx := range result.ports { diff --git a/listener/tun/ipstack/system/stack.go b/listener/tun/ipstack/system/stack.go index 92751d36..803e5db0 100644 --- a/listener/tun/ipstack/system/stack.go +++ b/listener/tun/ipstack/system/stack.go @@ -7,6 +7,7 @@ import ( "net/netip" "runtime" "strconv" + "sync" "time" "github.com/Dreamacro/clash/adapter/inbound" @@ -28,6 +29,8 @@ type sysStack struct { device device.Device closed bool + once sync.Once + wg sync.WaitGroup } func (s *sysStack) Close() error { @@ -38,10 +41,12 @@ func (s *sysStack) Close() error { }() s.closed = true - if s.stack != nil { - return s.stack.Close() - } - return nil + + err := s.stack.Close() + + s.wg.Wait() + + return err } func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { @@ -67,16 +72,10 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref _ = tcp.Close() }(stack.TCP()) - defer log.Debugln("TCP: closed") - for !ipStack.closed { - if err = stack.TCP().SetDeadline(time.Time{}); err != nil { - break - } - conn, err := stack.TCP().Accept() if err != nil { - log.Debugln("Accept connection: %v", err) + log.Debugln("[STACK] accept connection error: %v", err) continue } @@ -146,6 +145,8 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref tcpIn <- context.NewConnContext(conn, metadata) } + + ipStack.wg.Done() } udp := func() { @@ -153,14 +154,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref _ = udp.Close() }(stack.UDP()) - defer log.Debugln("UDP: closed") - for !ipStack.closed { buf := pool.Get(pool.UDPBufferSize) n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf) if err != nil { - return + _ = pool.Put(buf) + break } raw := buf[:n] @@ -209,17 +209,23 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref default: } } + + ipStack.wg.Done() } - go tcp() + ipStack.once.Do(func() { + ipStack.wg.Add(1) + go tcp() - numUDPWorkers := 4 - if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { - numUDPWorkers = num - } - for i := 0; i < numUDPWorkers; i++ { - go udp() - } + numUDPWorkers := 4 + if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { + numUDPWorkers = num + } + for i := 0; i < numUDPWorkers; i++ { + ipStack.wg.Add(1) + go udp() + } + }) return ipStack, nil } diff --git a/listener/tun/tun_adapter.go b/listener/tun/tun_adapter.go index 7461a492..98b0467c 100644 --- a/listener/tun/tun_adapter.go +++ b/listener/tun/tun_adapter.go @@ -145,6 +145,7 @@ func setAtLatest(stackType C.TUNStack, devName string) { case "darwin": // _, _ = cmd.ExecCmd("sysctl -w net.inet.ip.forwarding=1") // _, _ = cmd.ExecCmd("sysctl -w net.inet6.ip6.forwarding=1") + _, _ = cmd.ExecCmd("sudo launchctl limit maxfiles 10240 unlimited") case "windows": _, _ = cmd.ExecCmd("ipconfig /renew") case "linux": diff --git a/log/log.go b/log/log.go index 3a7ea729..be181fdd 100644 --- a/log/log.go +++ b/log/log.go @@ -10,8 +10,8 @@ import ( ) var ( - logCh = make(chan any) - source = observable.NewObservable(logCh) + logCh = make(chan Event) + source = observable.NewObservable[Event](logCh) level = INFO ) @@ -25,7 +25,7 @@ type Event struct { Payload string } -func (e *Event) Type() string { +func (e Event) Type() string { return e.LogLevel.String() } @@ -57,12 +57,12 @@ func Fatalln(format string, v ...any) { log.Fatalf(format, v...) } -func Subscribe() observable.Subscription { +func Subscribe() observable.Subscription[Event] { sub, _ := source.Subscribe() return sub } -func UnSubscribe(sub observable.Subscription) { +func UnSubscribe(sub observable.Subscription[Event]) { source.UnSubscribe(sub) } @@ -74,7 +74,7 @@ func SetLevel(newLevel LogLevel) { level = newLevel } -func print(data *Event) { +func print(data Event) { if data.LogLevel < level { return } @@ -91,8 +91,8 @@ func print(data *Event) { } } -func newLog(logLevel LogLevel, format string, v ...any) *Event { - return &Event{ +func newLog(logLevel LogLevel, format string, v ...any) Event { + return Event{ LogLevel: logLevel, Payload: fmt.Sprintf(format, v...), } diff --git a/test/go.mod b/test/go.mod index a24e2c63..3733d3ac 100644 --- a/test/go.mod +++ b/test/go.mod @@ -8,7 +8,7 @@ require ( github.com/docker/go-connections v0.4.0 github.com/miekg/dns v1.1.48 github.com/stretchr/testify v1.7.1 - golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 + golang.org/x/net v0.0.0-20220421235706-1d1ef9303861 ) replace github.com/Dreamacro/clash => ../ @@ -42,7 +42,7 @@ require ( golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect + golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect golang.org/x/tools v0.1.9 // indirect @@ -56,5 +56,5 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gotest.tools/v3 v3.1.0 // indirect - gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b // indirect + gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 // indirect ) diff --git a/test/go.sum b/test/go.sum index 1c7952fa..dbf6ed07 100644 --- a/test/go.sum +++ b/test/go.sum @@ -1012,8 +1012,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 h1:6mzvA99KwZxbOrxww4EvWVQUnN1+xEu9tafK5ZxkYeA= -golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220421235706-1d1ef9303861 h1:yssD99+7tqHWO5Gwh81phT+67hg+KttniBr6UnEXOY8= +golang.org/x/net v0.0.0-20220421235706-1d1ef9303861/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1143,8 +1143,8 @@ golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc= +golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -1413,8 +1413,8 @@ gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= gotest.tools/v3 v3.1.0 h1:rVV8Tcg/8jHUkPUorwjaMTtemIMVXfIPKiOqnhEhakk= gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ= -gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b h1:zBJp2eKSoNIV6+9LO3bRhlnuK280Oyrwc6OeFIN6VzU= -gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= +gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 h1:CSkd548jw5hmVwdJ+JuUhMtRV56oQBER7sbkIOePP2Y= +gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/transport/snell/pool.go b/transport/snell/pool.go index 62d21b4e..237baf21 100644 --- a/transport/snell/pool.go +++ b/transport/snell/pool.go @@ -11,7 +11,7 @@ import ( ) type Pool struct { - pool *pool.Pool + pool *pool.Pool[*Snell] } func (p *Pool) Get() (net.Conn, error) { @@ -24,12 +24,12 @@ func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) { return nil, err } - return &PoolConn{elm.(*Snell), p}, nil + return &PoolConn{elm, p}, nil } -func (p *Pool) Put(conn net.Conn) { +func (p *Pool) Put(conn *Snell) { if err := HalfClose(conn); err != nil { - conn.Close() + _ = conn.Close() return } @@ -64,22 +64,22 @@ func (pc *PoolConn) Write(b []byte) (int, error) { func (pc *PoolConn) Close() error { // clash use SetReadDeadline to break bidirectional copy between client and server. // reset it before reuse connection to avoid io timeout error. - pc.Snell.Conn.SetReadDeadline(time.Time{}) + _ = pc.Snell.Conn.SetReadDeadline(time.Time{}) pc.pool.Put(pc.Snell) return nil } func NewPool(factory func(context.Context) (*Snell, error)) *Pool { - p := pool.New( - func(ctx context.Context) (any, error) { + p := pool.New[*Snell]( + func(ctx context.Context) (*Snell, error) { return factory(ctx) }, - pool.WithAge(15000), - pool.WithSize(10), - pool.WithEvict(func(item any) { - item.(*Snell).Close() + pool.WithAge[*Snell](15000), + pool.WithSize[*Snell](10), + pool.WithEvict[*Snell](func(item *Snell) { + _ = item.Close() }), ) - return &Pool{p} + return &Pool{pool: p} } diff --git a/tunnel/connection.go b/tunnel/connection.go index 82443c35..0384e805 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -2,7 +2,6 @@ package tunnel import ( "errors" - "io" "net" "time" @@ -63,35 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n } func handleSocket(ctx C.ConnContext, outbound net.Conn) { - relay(ctx.Conn(), outbound) -} - -// relay copies between left and right bidirectionally. -func relay(leftConn, rightConn net.Conn) { - ch := make(chan error) - - tcpKeepAlive(leftConn) - tcpKeepAlive(rightConn) - - go func() { - buf := pool.Get(pool.RelayBufferSize) - // Wrapping to avoid using *net.TCPConn.(ReadFrom) - // See also https://github.com/Dreamacro/clash/pull/1209 - _, err := io.CopyBuffer(N.WriteOnlyWriter{Writer: leftConn}, N.ReadOnlyReader{Reader: rightConn}, buf) - pool.Put(buf) - leftConn.SetReadDeadline(time.Now()) - ch <- err - }() - - buf := pool.Get(pool.RelayBufferSize) - io.CopyBuffer(N.WriteOnlyWriter{Writer: rightConn}, N.ReadOnlyReader{Reader: leftConn}, buf) - pool.Put(buf) - rightConn.SetReadDeadline(time.Now()) - <-ch -} - -func tcpKeepAlive(c net.Conn) { - if tcp, ok := c.(*net.TCPConn); ok { - tcp.SetKeepAlive(true) - } + N.Relay(ctx.Conn(), outbound) }