Chore: merge branch 'with-tun' into plus-pro

This commit is contained in:
yaling888 2022-04-27 05:49:45 +08:00
commit 1e7cbd6358
50 changed files with 903 additions and 483 deletions

View File

@ -14,7 +14,7 @@ import (
type Fallback struct { type Fallback struct {
*outbound.Base *outbound.Base
disableUDP bool disableUDP bool
single *singledo.Single single *singledo.Single[[]C.Proxy]
providers []provider.ProxyProvider providers []provider.ProxyProvider
} }
@ -73,11 +73,11 @@ func (f *Fallback) Unwrap(metadata *C.Metadata) C.Proxy {
} }
func (f *Fallback) proxies(touch bool) []C.Proxy { func (f *Fallback) proxies(touch bool) []C.Proxy {
elm, _, _ := f.single.Do(func() (any, error) { elm, _, _ := f.single.Do(func() ([]C.Proxy, error) {
return getProvidersProxies(f.providers, touch), nil return getProvidersProxies(f.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm
} }
func (f *Fallback) findAliveProxy(touch bool) C.Proxy { func (f *Fallback) findAliveProxy(touch bool) C.Proxy {
@ -99,7 +99,7 @@ func NewFallback(option *GroupCommonOption, providers []provider.ProxyProvider)
Interface: option.Interface, Interface: option.Interface,
RoutingMark: option.RoutingMark, RoutingMark: option.RoutingMark,
}), }),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration),
providers: providers, providers: providers,
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,
} }

View File

@ -22,7 +22,7 @@ type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy
type LoadBalance struct { type LoadBalance struct {
*outbound.Base *outbound.Base
disableUDP bool disableUDP bool
single *singledo.Single single *singledo.Single[[]C.Proxy]
providers []provider.ProxyProvider providers []provider.ProxyProvider
strategyFn strategyFn strategyFn strategyFn
} }
@ -140,11 +140,11 @@ func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
} }
func (lb *LoadBalance) proxies(touch bool) []C.Proxy { func (lb *LoadBalance) proxies(touch bool) []C.Proxy {
elm, _, _ := lb.single.Do(func() (any, error) { elm, _, _ := lb.single.Do(func() ([]C.Proxy, error) {
return getProvidersProxies(lb.providers, touch), nil return getProvidersProxies(lb.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm
} }
// MarshalJSON implements C.ProxyAdapter // MarshalJSON implements C.ProxyAdapter
@ -176,7 +176,7 @@ func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvide
Interface: option.Interface, Interface: option.Interface,
RoutingMark: option.RoutingMark, RoutingMark: option.RoutingMark,
}), }),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration),
providers: providers, providers: providers,
strategyFn: strategyFn, strategyFn: strategyFn,
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,

View File

@ -14,7 +14,7 @@ import (
type Relay struct { type Relay struct {
*outbound.Base *outbound.Base
single *singledo.Single single *singledo.Single[[]C.Proxy]
providers []provider.ProxyProvider providers []provider.ProxyProvider
} }
@ -79,11 +79,11 @@ func (r *Relay) MarshalJSON() ([]byte, error) {
} }
func (r *Relay) rawProxies(touch bool) []C.Proxy { func (r *Relay) rawProxies(touch bool) []C.Proxy {
elm, _, _ := r.single.Do(func() (any, error) { elm, _, _ := r.single.Do(func() ([]C.Proxy, error) {
return getProvidersProxies(r.providers, touch), nil return getProvidersProxies(r.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm
} }
func (r *Relay) proxies(metadata *C.Metadata, touch bool) []C.Proxy { func (r *Relay) proxies(metadata *C.Metadata, touch bool) []C.Proxy {
@ -108,7 +108,7 @@ func NewRelay(option *GroupCommonOption, providers []provider.ProxyProvider) *Re
Interface: option.Interface, Interface: option.Interface,
RoutingMark: option.RoutingMark, RoutingMark: option.RoutingMark,
}), }),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration),
providers: providers, providers: providers,
} }
} }

View File

@ -15,7 +15,7 @@ import (
type Selector struct { type Selector struct {
*outbound.Base *outbound.Base
disableUDP bool disableUDP bool
single *singledo.Single single *singledo.Single[C.Proxy]
selected string selected string
providers []provider.ProxyProvider providers []provider.ProxyProvider
} }
@ -83,7 +83,7 @@ func (s *Selector) Unwrap(metadata *C.Metadata) C.Proxy {
} }
func (s *Selector) selectedProxy(touch bool) C.Proxy { func (s *Selector) selectedProxy(touch bool) C.Proxy {
elm, _, _ := s.single.Do(func() (any, error) { elm, _, _ := s.single.Do(func() (C.Proxy, error) {
proxies := getProvidersProxies(s.providers, touch) proxies := getProvidersProxies(s.providers, touch)
for _, proxy := range proxies { for _, proxy := range proxies {
if proxy.Name() == s.selected { if proxy.Name() == s.selected {
@ -94,7 +94,7 @@ func (s *Selector) selectedProxy(touch bool) C.Proxy {
return proxies[0], nil return proxies[0], nil
}) })
return elm.(C.Proxy) return elm
} }
func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider) *Selector { func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider) *Selector {
@ -106,7 +106,7 @@ func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider)
Interface: option.Interface, Interface: option.Interface,
RoutingMark: option.RoutingMark, RoutingMark: option.RoutingMark,
}), }),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle[C.Proxy](defaultGetProxiesDuration),
providers: providers, providers: providers,
selected: selected, selected: selected,
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,

View File

@ -25,8 +25,8 @@ type URLTest struct {
tolerance uint16 tolerance uint16
disableUDP bool disableUDP bool
fastNode C.Proxy fastNode C.Proxy
single *singledo.Single single *singledo.Single[[]C.Proxy]
fastSingle *singledo.Single fastSingle *singledo.Single[C.Proxy]
providers []provider.ProxyProvider providers []provider.ProxyProvider
} }
@ -58,15 +58,15 @@ func (u *URLTest) Unwrap(metadata *C.Metadata) C.Proxy {
} }
func (u *URLTest) proxies(touch bool) []C.Proxy { func (u *URLTest) proxies(touch bool) []C.Proxy {
elm, _, _ := u.single.Do(func() (any, error) { elm, _, _ := u.single.Do(func() ([]C.Proxy, error) {
return getProvidersProxies(u.providers, touch), nil return getProvidersProxies(u.providers, touch), nil
}) })
return elm.([]C.Proxy) return elm
} }
func (u *URLTest) fast(touch bool) C.Proxy { func (u *URLTest) fast(touch bool) C.Proxy {
elm, _, _ := u.fastSingle.Do(func() (any, error) { elm, _, _ := u.fastSingle.Do(func() (C.Proxy, error) {
proxies := u.proxies(touch) proxies := u.proxies(touch)
fast := proxies[0] fast := proxies[0]
min := fast.LastDelay() min := fast.LastDelay()
@ -96,7 +96,7 @@ func (u *URLTest) fast(touch bool) C.Proxy {
return u.fastNode, nil return u.fastNode, nil
}) })
return elm.(C.Proxy) return elm
} }
// SupportUDP implements C.ProxyAdapter // SupportUDP implements C.ProxyAdapter
@ -142,8 +142,8 @@ func NewURLTest(option *GroupCommonOption, providers []provider.ProxyProvider, o
Interface: option.Interface, Interface: option.Interface,
RoutingMark: option.RoutingMark, RoutingMark: option.RoutingMark,
}), }),
single: singledo.NewSingle(defaultGetProxiesDuration), single: singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration),
fastSingle: singledo.NewSingle(time.Second * 10), fastSingle: singledo.NewSingle[C.Proxy](time.Second * 10),
providers: providers, providers: providers,
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,
} }

View File

@ -65,14 +65,14 @@ func (hc *HealthCheck) touch() {
} }
func (hc *HealthCheck) check() { func (hc *HealthCheck) check() {
b, _ := batch.New(context.Background(), batch.WithConcurrencyNum(10)) b, _ := batch.New[bool](context.Background(), batch.WithConcurrencyNum[bool](10))
for _, proxy := range hc.proxies { for _, proxy := range hc.proxies {
p := proxy p := proxy
b.Go(p.Name(), func() (any, error) { b.Go(p.Name(), func() (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout)
defer cancel() defer cancel()
p.URLTest(ctx, hc.url) _, _ = p.URLTest(ctx, hc.url)
return nil, nil return false, nil
}) })
} }
b.Wait() b.Wait()

View File

@ -5,10 +5,10 @@ import (
"sync" "sync"
) )
type Option = func(b *Batch) type Option[T any] func(b *Batch[T])
type Result struct { type Result[T any] struct {
Value any Value T
Err error Err error
} }
@ -17,8 +17,8 @@ type Error struct {
Err error Err error
} }
func WithConcurrencyNum(n int) Option { func WithConcurrencyNum[T any](n int) Option[T] {
return func(b *Batch) { return func(b *Batch[T]) {
q := make(chan struct{}, n) q := make(chan struct{}, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
q <- struct{}{} q <- struct{}{}
@ -28,8 +28,8 @@ func WithConcurrencyNum(n int) Option {
} }
// Batch similar to errgroup, but can control the maximum number of concurrent // Batch similar to errgroup, but can control the maximum number of concurrent
type Batch struct { type Batch[T any] struct {
result map[string]Result result map[string]Result[T]
queue chan struct{} queue chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
mux sync.Mutex mux sync.Mutex
@ -38,7 +38,7 @@ type Batch struct {
cancel func() cancel func()
} }
func (b *Batch) Go(key string, fn func() (any, error)) { func (b *Batch[T]) Go(key string, fn func() (T, error)) {
b.wg.Add(1) b.wg.Add(1)
go func() { go func() {
defer b.wg.Done() defer b.wg.Done()
@ -59,14 +59,14 @@ func (b *Batch) Go(key string, fn func() (any, error)) {
}) })
} }
ret := Result{value, err} ret := Result[T]{value, err}
b.mux.Lock() b.mux.Lock()
defer b.mux.Unlock() defer b.mux.Unlock()
b.result[key] = ret b.result[key] = ret
}() }()
} }
func (b *Batch) Wait() *Error { func (b *Batch[T]) Wait() *Error {
b.wg.Wait() b.wg.Wait()
if b.cancel != nil { if b.cancel != nil {
b.cancel() b.cancel()
@ -74,26 +74,26 @@ func (b *Batch) Wait() *Error {
return b.err return b.err
} }
func (b *Batch) WaitAndGetResult() (map[string]Result, *Error) { func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) {
err := b.Wait() err := b.Wait()
return b.Result(), err return b.Result(), err
} }
func (b *Batch) Result() map[string]Result { func (b *Batch[T]) Result() map[string]Result[T] {
b.mux.Lock() b.mux.Lock()
defer b.mux.Unlock() defer b.mux.Unlock()
copy := map[string]Result{} copyM := map[string]Result[T]{}
for k, v := range b.result { for k, v := range b.result {
copy[k] = v copyM[k] = v
} }
return copy return copyM
} }
func New(ctx context.Context, opts ...Option) (*Batch, context.Context) { func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
b := &Batch{ b := &Batch[T]{
result: map[string]Result{}, result: map[string]Result[T]{},
} }
for _, o := range opts { for _, o := range opts {

View File

@ -11,14 +11,14 @@ import (
) )
func TestBatch(t *testing.T) { func TestBatch(t *testing.T) {
b, _ := New(context.Background()) b, _ := New[string](context.Background())
now := time.Now() now := time.Now()
b.Go("foo", func() (any, error) { b.Go("foo", func() (string, error) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return "foo", nil return "foo", nil
}) })
b.Go("bar", func() (any, error) { b.Go("bar", func() (string, error) {
time.Sleep(time.Millisecond * 150) time.Sleep(time.Millisecond * 150)
return "bar", nil return "bar", nil
}) })
@ -32,20 +32,20 @@ func TestBatch(t *testing.T) {
for k, v := range result { for k, v := range result {
assert.NoError(t, v.Err) assert.NoError(t, v.Err)
assert.Equal(t, k, v.Value.(string)) assert.Equal(t, k, v.Value)
} }
} }
func TestBatchWithConcurrencyNum(t *testing.T) { func TestBatchWithConcurrencyNum(t *testing.T) {
b, _ := New( b, _ := New[string](
context.Background(), context.Background(),
WithConcurrencyNum(3), WithConcurrencyNum[string](3),
) )
now := time.Now() now := time.Now()
for i := 0; i < 7; i++ { for i := 0; i < 7; i++ {
idx := i idx := i
b.Go(strconv.Itoa(idx), func() (any, error) { b.Go(strconv.Itoa(idx), func() (string, error) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return strconv.Itoa(idx), nil return strconv.Itoa(idx), nil
}) })
@ -57,21 +57,21 @@ func TestBatchWithConcurrencyNum(t *testing.T) {
for k, v := range result { for k, v := range result {
assert.NoError(t, v.Err) assert.NoError(t, v.Err)
assert.Equal(t, k, v.Value.(string)) assert.Equal(t, k, v.Value)
} }
} }
func TestBatchContext(t *testing.T) { func TestBatchContext(t *testing.T) {
b, ctx := New(context.Background()) b, ctx := New[string](context.Background())
b.Go("error", func() (any, error) { b.Go("error", func() (string, error) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return nil, errors.New("test error") return "", errors.New("test error")
}) })
b.Go("ctx", func() (any, error) { b.Go("ctx", func() (string, error) {
<-ctx.Done() <-ctx.Done()
return nil, ctx.Err() return "", ctx.Err()
}) })
result, err := b.WaitAndGetResult() result, err := b.WaitAndGetResult()

View File

@ -3,19 +3,20 @@ package cache
// Modified by https://github.com/die-net/lrucache // Modified by https://github.com/die-net/lrucache
import ( import (
"container/list"
"sync" "sync"
"time" "time"
"github.com/Dreamacro/clash/common/generics/list"
) )
// Option is part of Functional Options Pattern // Option is part of Functional Options Pattern
type Option[K comparable, V any] func(*LruCache[K, V]) type Option[K comparable, V any] func(*LruCache[K, V])
// EvictCallback is used to get a callback when a cache entry is evicted // EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback = func(key any, value any) type EvictCallback[K comparable, V any] func(key K, value V)
// WithEvict set the evict callback // WithEvict set the evict callback
func WithEvict[K comparable, V any](cb EvictCallback) Option[K, V] { func WithEvict[K comparable, V any](cb EvictCallback[K, V]) Option[K, V] {
return func(l *LruCache[K, V]) { return func(l *LruCache[K, V]) {
l.onEvict = cb l.onEvict = cb
} }
@ -57,18 +58,18 @@ type LruCache[K comparable, V any] struct {
maxAge int64 maxAge int64
maxSize int maxSize int
mu sync.Mutex mu sync.Mutex
cache map[any]*list.Element cache map[K]*list.Element[*entry[K, V]]
lru *list.List // Front is least-recent lru *list.List[*entry[K, V]] // Front is least-recent
updateAgeOnGet bool updateAgeOnGet bool
staleReturn bool staleReturn bool
onEvict EvictCallback onEvict EvictCallback[K, V]
} }
// NewLRUCache creates an LruCache // NewLRUCache creates an LruCache
func NewLRUCache[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { func NewLRUCache[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
lc := &LruCache[K, V]{ lc := &LruCache[K, V]{
lru: list.New(), lru: list.New[*entry[K, V]](),
cache: make(map[any]*list.Element), cache: make(map[K]*list.Element[*entry[K, V]]),
} }
for _, option := range options { for _, option := range options {
@ -129,7 +130,7 @@ func (c *LruCache[K, V]) SetWithExpire(key K, value V, expires time.Time) {
if le, ok := c.cache[key]; ok { if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le) c.lru.MoveToBack(le)
e := le.Value.(*entry[K, V]) e := le.Value
e.value = value e.value = value
e.expires = expires.Unix() e.expires = expires.Unix()
} else { } else {
@ -154,11 +155,11 @@ func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) {
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
n.lru = list.New() n.lru = list.New[*entry[K, V]]()
n.cache = make(map[any]*list.Element) n.cache = make(map[K]*list.Element[*entry[K, V]])
for e := c.lru.Front(); e != nil; e = e.Next() { for e := c.lru.Front(); e != nil; e = e.Next() {
elm := e.Value.(*entry[K, V]) elm := e.Value
n.cache[elm.key] = n.lru.PushBack(elm) n.cache[elm.key] = n.lru.PushBack(elm)
} }
} }
@ -172,7 +173,7 @@ func (c *LruCache[K, V]) get(key K) *entry[K, V] {
return nil return nil
} }
if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry[K, V]).expires <= time.Now().Unix() { if !c.staleReturn && c.maxAge > 0 && le.Value.expires <= time.Now().Unix() {
c.deleteElement(le) c.deleteElement(le)
c.maybeDeleteOldest() c.maybeDeleteOldest()
@ -180,7 +181,7 @@ func (c *LruCache[K, V]) get(key K) *entry[K, V] {
} }
c.lru.MoveToBack(le) c.lru.MoveToBack(le)
el := le.Value.(*entry[K, V]) el := le.Value
if c.maxAge > 0 && c.updateAgeOnGet { if c.maxAge > 0 && c.updateAgeOnGet {
el.expires = time.Now().Unix() + c.maxAge el.expires = time.Now().Unix() + c.maxAge
} }
@ -201,15 +202,15 @@ func (c *LruCache[K, V]) Delete(key K) {
func (c *LruCache[K, V]) maybeDeleteOldest() { func (c *LruCache[K, V]) maybeDeleteOldest() {
if !c.staleReturn && c.maxAge > 0 { if !c.staleReturn && c.maxAge > 0 {
now := time.Now().Unix() now := time.Now().Unix()
for le := c.lru.Front(); le != nil && le.Value.(*entry[K, V]).expires <= now; le = c.lru.Front() { for le := c.lru.Front(); le != nil && le.Value.expires <= now; le = c.lru.Front() {
c.deleteElement(le) c.deleteElement(le)
} }
} }
} }
func (c *LruCache[K, V]) deleteElement(le *list.Element) { func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) {
c.lru.Remove(le) c.lru.Remove(le)
e := le.Value.(*entry[K, V]) e := le.Value
delete(c.cache, e.key) delete(c.cache, e.key)
if c.onEvict != nil { if c.onEvict != nil {
c.onEvict(e.key, e.value) c.onEvict(e.key, e.value)
@ -219,7 +220,7 @@ func (c *LruCache[K, V]) deleteElement(le *list.Element) {
func (c *LruCache[K, V]) Clear() error { func (c *LruCache[K, V]) Clear() error {
c.mu.Lock() c.mu.Lock()
c.cache = make(map[any]*list.Element) c.cache = make(map[K]*list.Element[*entry[K, V]])
c.mu.Unlock() c.mu.Unlock()
return nil return nil

View File

@ -52,18 +52,18 @@ func TestLRUMaxAge(t *testing.T) {
// Add one expired entry // Add one expired entry
c.Set("foo", "bar") c.Set("foo", "bar")
c.lru.Back().Value.(*entry[string, string]).expires = now c.lru.Back().Value.expires = now
// Reset // Reset
c.Set("foo", "bar") c.Set("foo", "bar")
e := c.lru.Back().Value.(*entry[string, string]) e := c.lru.Back().Value
assert.True(t, e.expires >= now) assert.True(t, e.expires >= now)
c.lru.Back().Value.(*entry[string, string]).expires = now c.lru.Back().Value.expires = now
// Set a few and verify expiration times // Set a few and verify expiration times
for _, s := range entries { for _, s := range entries {
c.Set(s.key, s.value) c.Set(s.key, s.value)
e := c.lru.Back().Value.(*entry[string, string]) e := c.lru.Back().Value
assert.True(t, e.expires >= expected && e.expires <= expected+10) assert.True(t, e.expires >= expected && e.expires <= expected+10)
} }
@ -77,7 +77,7 @@ func TestLRUMaxAge(t *testing.T) {
for _, s := range entries { for _, s := range entries {
le, ok := c.cache[s.key] le, ok := c.cache[s.key]
if assert.True(t, ok) { if assert.True(t, ok) {
le.Value.(*entry[string, string]).expires = now le.Value.expires = now
} }
} }
@ -95,11 +95,11 @@ func TestLRUpdateOnGet(t *testing.T) {
// Add one expired entry // Add one expired entry
c.Set("foo", "bar") c.Set("foo", "bar")
c.lru.Back().Value.(*entry[string, string]).expires = expires c.lru.Back().Value.expires = expires
_, ok := c.Get("foo") _, ok := c.Get("foo")
assert.True(t, ok) assert.True(t, ok)
assert.True(t, c.lru.Back().Value.(*entry[string, string]).expires > expires) assert.True(t, c.lru.Back().Value.expires > expires)
} }
func TestMaxSize(t *testing.T) { func TestMaxSize(t *testing.T) {
@ -126,8 +126,8 @@ func TestExist(t *testing.T) {
func TestEvict(t *testing.T) { func TestEvict(t *testing.T) {
temp := 0 temp := 0
evict := func(key any, value any) { evict := func(key int, value int) {
temp = key.(int) + value.(int) temp = key + value
} }
c := NewLRUCache[int, int](WithEvict[int, int](evict), WithSize[int, int](1)) c := NewLRUCache[int, int](WithEvict[int, int](evict), WithSize[int, int](1))

View File

@ -11,6 +11,7 @@ import (
"math/big" "math/big"
"net" "net"
"os" "os"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
) )
@ -38,19 +39,6 @@ type CertsStorage interface {
Set(key string, cert *tls.Certificate) Set(key string, cert *tls.Certificate)
} }
type CertsCache struct {
certsCache map[string]*tls.Certificate
}
func (c *CertsCache) Get(key string) (*tls.Certificate, bool) {
v, ok := c.certsCache[key]
return v, ok
}
func (c *CertsCache) Set(key string, cert *tls.Certificate) {
c.certsCache[key] = cert
}
func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) { func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
@ -100,7 +88,7 @@ func NewAuthority(name, organization string, validity time.Duration) (*x509.Cert
return x509c, privateKey, nil return x509c, privateKey, nil
} }
func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage CertsStorage) (*Config, error) { func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*Config, error) {
roots := x509.NewCertPool() roots := x509.NewCertPool()
roots.AddCert(ca) roots.AddCert(ca)
@ -121,10 +109,6 @@ func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage Certs
} }
keyID := h.Sum(nil) keyID := h.Sum(nil)
if storage == nil {
storage = &CertsCache{certsCache: make(map[string]*tls.Certificate)}
}
return &Config{ return &Config{
ca: ca, ca: ca,
caPrivateKey: caPrivateKey, caPrivateKey: caPrivateKey,
@ -132,7 +116,7 @@ func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage Certs
keyID: keyID, keyID: keyID,
validity: time.Hour, validity: time.Hour,
organization: "Clash", organization: "Clash",
certsStorage: storage, certsStorage: NewDomainTrieCertsStorage(),
roots: roots, roots: roots,
}, nil }, nil
} }
@ -168,14 +152,11 @@ func (c *Config) NewTLSConfigForHost(hostname string) *tls.Config {
} }
func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) { func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) {
host, _, err := net.SplitHostPort(hostname) var leaf *x509.Certificate
if err == nil {
hostname = host
}
tlsCertificate, ok := c.certsStorage.Get(hostname) tlsCertificate, ok := c.certsStorage.Get(hostname)
if ok { if ok {
if _, err = tlsCertificate.Leaf.Verify(x509.VerifyOptions{ leaf = tlsCertificate.Leaf
if _, err := leaf.Verify(x509.VerifyOptions{
DNSName: hostname, DNSName: hostname,
Roots: c.roots, Roots: c.roots,
}); err == nil { }); err == nil {
@ -183,12 +164,49 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica
} }
} }
var (
key = hostname
topHost = hostname
wildcardHost = "*." + hostname
dnsNames []string
)
if ip := net.ParseIP(hostname); ip != nil {
ips = append(ips, ip)
} else {
parts := strings.Split(hostname, ".")
l := len(parts)
if leaf != nil {
dnsNames = append(dnsNames, leaf.DNSNames...)
}
if l > 2 {
topIndex := l - 2
topHost = strings.Join(parts[topIndex:], ".")
for i := topIndex; i > 0; i-- {
wildcardHost = "*." + strings.Join(parts[i:], ".")
if i == topIndex && (len(dnsNames) == 0 || dnsNames[0] != topHost) {
dnsNames = append(dnsNames, topHost, wildcardHost)
} else if !hasDnsNames(dnsNames, wildcardHost) {
dnsNames = append(dnsNames, wildcardHost)
}
}
} else {
dnsNames = append(dnsNames, topHost, wildcardHost)
}
key = "+." + topHost
}
serial := atomic.AddInt64(&currentSerialNumber, 1) serial := atomic.AddInt64(&currentSerialNumber, 1)
tmpl := &x509.Certificate{ tmpl := &x509.Certificate{
SerialNumber: big.NewInt(serial), SerialNumber: big.NewInt(serial),
Subject: pkix.Name{ Subject: pkix.Name{
CommonName: hostname, CommonName: topHost,
Organization: []string{c.organization}, Organization: []string{c.organization},
}, },
SubjectKeyId: c.keyID, SubjectKeyId: c.keyID,
@ -197,16 +215,10 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica
BasicConstraintsValid: true, BasicConstraintsValid: true,
NotBefore: time.Now().Add(-c.validity), NotBefore: time.Now().Add(-c.validity),
NotAfter: time.Now().Add(c.validity), NotAfter: time.Now().Add(c.validity),
DNSNames: dnsNames,
IPAddresses: ips,
} }
if ip := net.ParseIP(hostname); ip != nil {
ips = append(ips, ip)
} else {
tmpl.DNSNames = []string{hostname}
}
tmpl.IPAddresses = ips
raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey) raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -223,7 +235,7 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica
Leaf: x509c, Leaf: x509c,
} }
c.certsStorage.Set(hostname, tlsCertificate) c.certsStorage.Set(key, tlsCertificate)
return tlsCertificate, nil return tlsCertificate, nil
} }
@ -280,3 +292,12 @@ func GenerateAndSave(caPath string, caKeyPath string) error {
return nil return nil
} }
func hasDnsNames(dnsNames []string, hostname string) bool {
for _, name := range dnsNames {
if name == hostname {
return true
}
}
return false
}

View File

@ -18,7 +18,7 @@ func TestCert(t *testing.T) {
assert.NotNil(t, ca) assert.NotNil(t, ca)
assert.NotNil(t, privateKey) assert.NotNil(t, privateKey)
c, err := NewConfig(ca, privateKey, nil) c, err := NewConfig(ca, privateKey)
assert.Nil(t, err) assert.Nil(t, err)
c.SetValidity(20 * time.Hour) c.SetValidity(20 * time.Hour)
@ -40,27 +40,55 @@ func TestCert(t *testing.T) {
x509c := tlsCert.Leaf x509c := tlsCert.Leaf
assert.Equal(t, "example.org", x509c.Subject.CommonName) assert.Equal(t, "example.org", x509c.Subject.CommonName)
assert.Nil(t, x509c.VerifyHostname("example.org")) assert.Nil(t, x509c.VerifyHostname("example.org"))
assert.Nil(t, x509c.VerifyHostname("abc.example.org"))
assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization) assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization)
assert.NotNil(t, x509c.SubjectKeyId) assert.NotNil(t, x509c.SubjectKeyId)
assert.True(t, x509c.BasicConstraintsValid) assert.True(t, x509c.BasicConstraintsValid)
assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment) assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment)
assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature) assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature)
assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage)
assert.Equal(t, []string{"example.org"}, x509c.DNSNames) assert.Equal(t, []string{"example.org", "*.example.org"}, x509c.DNSNames)
assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour))) assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour)))
assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour))) assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour)))
// Check that certificate is cached // Check that certificate is cached
tlsCert2, err := c.GetOrCreateCert("example.org") tlsCert2, err := c.GetOrCreateCert("abc.example.org")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, tlsCert == tlsCert2) assert.True(t, tlsCert == tlsCert2)
// Check the certificate for an IP // Check that certificate is new
tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1:443") _, _ = c.GetOrCreateCert("a.b.c.d.e.f.g.h.i.j.example.org")
tlsCert3, err := c.GetOrCreateCert("m.k.l.example.org")
x509c = tlsCert3.Leaf
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, tlsCert == tlsCert3)
assert.Equal(t, []string{"example.org", "*.example.org", "*.j.example.org", "*.i.j.example.org", "*.h.i.j.example.org", "*.g.h.i.j.example.org", "*.f.g.h.i.j.example.org", "*.e.f.g.h.i.j.example.org", "*.d.e.f.g.h.i.j.example.org", "*.c.d.e.f.g.h.i.j.example.org", "*.b.c.d.e.f.g.h.i.j.example.org", "*.l.example.org", "*.k.l.example.org"}, x509c.DNSNames)
// Check that certificate is cached
tlsCert4, err := c.GetOrCreateCert("xyz.example.org")
x509c = tlsCert4.Leaf
assert.Nil(t, err)
assert.True(t, tlsCert3 == tlsCert4)
assert.Nil(t, x509c.VerifyHostname("example.org"))
assert.Nil(t, x509c.VerifyHostname("jkf.example.org"))
assert.Nil(t, x509c.VerifyHostname("n.j.example.org"))
assert.Nil(t, x509c.VerifyHostname("c.i.j.example.org"))
assert.Nil(t, x509c.VerifyHostname("m.l.example.org"))
assert.Error(t, x509c.VerifyHostname("m.l.jkf.example.org"))
// Check the certificate for an IP
tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1")
x509c = tlsCertForIP.Leaf x509c = tlsCertForIP.Leaf
assert.Nil(t, err)
assert.Equal(t, 1, len(x509c.IPAddresses)) assert.Equal(t, 1, len(x509c.IPAddresses))
assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0])) assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0]))
// Check that certificate is cached
tlsCertForIP2, err := c.GetOrCreateCert("192.168.0.1")
x509c = tlsCertForIP2.Leaf
assert.Nil(t, err)
assert.True(t, tlsCertForIP == tlsCertForIP2)
assert.Nil(t, x509c.VerifyHostname("192.168.0.1"))
} }
func TestGenerateAndSave(t *testing.T) { func TestGenerateAndSave(t *testing.T) {

View File

@ -2,31 +2,31 @@ package cert
import ( import (
"crypto/tls" "crypto/tls"
"time"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/component/trie"
) )
var TTL = time.Hour * 2 // DomainTrieCertsStorage cache wildcard certificates
type DomainTrieCertsStorage struct {
// AutoGCCertsStorage cache with the generated certificates, auto released after TTL certsCache *trie.DomainTrie[*tls.Certificate]
type AutoGCCertsStorage struct {
certsCache *cache.Cache[string, *tls.Certificate]
} }
// Get gets the certificate from the storage // Get gets the certificate from the storage
func (c *AutoGCCertsStorage) Get(key string) (*tls.Certificate, bool) { func (c *DomainTrieCertsStorage) Get(key string) (*tls.Certificate, bool) {
ca := c.certsCache.Get(key) ca := c.certsCache.Search(key)
return ca, ca != nil if ca == nil {
return nil, false
}
return ca.Data, true
} }
// Set saves the certificate to the storage // Set saves the certificate to the storage
func (c *AutoGCCertsStorage) Set(key string, cert *tls.Certificate) { func (c *DomainTrieCertsStorage) Set(key string, cert *tls.Certificate) {
c.certsCache.Put(key, cert, TTL) _ = c.certsCache.Insert(key, cert)
} }
func NewAutoGCCertsStorage() *AutoGCCertsStorage { func NewDomainTrieCertsStorage() *DomainTrieCertsStorage {
return &AutoGCCertsStorage{ return &DomainTrieCertsStorage{
certsCache: cache.New[string, *tls.Certificate](TTL), certsCache: trie.New[*tls.Certificate](),
} }
} }

View File

@ -0,0 +1,235 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package list implements a doubly linked list.
//
// To iterate over a list (where l is a *List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.Value
// }
//
package list
// Element is an element of a linked list.
type Element[T any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Element[T]
// The list to which this element belongs.
list *List[T]
// The value stored with this element.
Value T
}
// Next returns the next list element or nil.
func (e *Element[T]) Next() *Element[T] {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *Element[T]) Prev() *Element[T] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// List represents a doubly linked list.
// The zero value for List is an empty list ready to use.
type List[T any] struct {
root Element[T] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *List[T]) Init() *List[T] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// New returns an initialized list.
func New[T any]() *List[T] { return new(List[T]).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *List[T]) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *List[T]) Front() *Element[T] {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *List[T]) Back() *Element[T] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *List[T]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *List[T]) insert(e, at *Element[T]) *Element[T] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] {
return l.insert(&Element[T]{Value: v}, at)
}
// remove removes e from its list, decrements l.len
func (l *List[T]) remove(e *Element[T]) {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
}
// move moves e to next to at.
func (l *List[T]) move(e, at *Element[T]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *List[T]) Remove(e *Element[T]) T {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *List[T]) PushFront(v T) *Element[T] {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *List[T]) PushBack(v T) *Element[T] {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List[T]) MoveToFront(e *Element[T]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List[T]) MoveToBack(e *Element[T]) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List[T]) MoveBefore(e, mark *Element[T]) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List[T]) MoveAfter(e, mark *Element[T]) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark)
}
// PushBackList inserts a copy of another list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List[T]) PushBackList(other *List[T]) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of another list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List[T]) PushFrontList(other *List[T]) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

39
common/net/relay.go Normal file
View File

@ -0,0 +1,39 @@
package net
import (
"io"
"net"
"time"
"github.com/Dreamacro/clash/common/pool"
)
// Relay copies between left and right bidirectionally.
func Relay(leftConn, rightConn net.Conn) {
ch := make(chan error)
tcpKeepAlive(leftConn)
tcpKeepAlive(rightConn)
go func() {
buf := pool.Get(pool.RelayBufferSize)
// Wrapping to avoid using *net.TCPConn.(ReadFrom)
// See also https://github.com/Dreamacro/clash/pull/1209
_, err := io.CopyBuffer(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn}, buf)
_ = pool.Put(buf)
_ = leftConn.SetReadDeadline(time.Now())
ch <- err
}()
buf := pool.Get(pool.RelayBufferSize)
_, _ = io.CopyBuffer(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}, buf)
_ = pool.Put(buf)
_ = rightConn.SetReadDeadline(time.Now())
<-ch
}
func tcpKeepAlive(c net.Conn) {
if tcp, ok := c.(*net.TCPConn); ok {
_ = tcp.SetKeepAlive(true)
}
}

View File

@ -1,3 +1,3 @@
package observable package observable
type Iterable <-chan any type Iterable[T any] <-chan T

View File

@ -5,14 +5,14 @@ import (
"sync" "sync"
) )
type Observable struct { type Observable[T any] struct {
iterable Iterable iterable Iterable[T]
listener map[Subscription]*Subscriber listener map[Subscription[T]]*Subscriber[T]
mux sync.Mutex mux sync.Mutex
done bool done bool
} }
func (o *Observable) process() { func (o *Observable[T]) process() {
for item := range o.iterable { for item := range o.iterable {
o.mux.Lock() o.mux.Lock()
for _, sub := range o.listener { for _, sub := range o.listener {
@ -23,7 +23,7 @@ func (o *Observable) process() {
o.close() o.close()
} }
func (o *Observable) close() { func (o *Observable[T]) close() {
o.mux.Lock() o.mux.Lock()
defer o.mux.Unlock() defer o.mux.Unlock()
@ -33,18 +33,18 @@ func (o *Observable) close() {
} }
} }
func (o *Observable) Subscribe() (Subscription, error) { func (o *Observable[T]) Subscribe() (Subscription[T], error) {
o.mux.Lock() o.mux.Lock()
defer o.mux.Unlock() defer o.mux.Unlock()
if o.done { if o.done {
return nil, errors.New("Observable is closed") return nil, errors.New("observable is closed")
} }
subscriber := newSubscriber() subscriber := newSubscriber[T]()
o.listener[subscriber.Out()] = subscriber o.listener[subscriber.Out()] = subscriber
return subscriber.Out(), nil return subscriber.Out(), nil
} }
func (o *Observable) UnSubscribe(sub Subscription) { func (o *Observable[T]) UnSubscribe(sub Subscription[T]) {
o.mux.Lock() o.mux.Lock()
defer o.mux.Unlock() defer o.mux.Unlock()
subscriber, exist := o.listener[sub] subscriber, exist := o.listener[sub]
@ -55,10 +55,10 @@ func (o *Observable) UnSubscribe(sub Subscription) {
subscriber.Close() subscriber.Close()
} }
func NewObservable(any Iterable) *Observable { func NewObservable[T any](iter Iterable[T]) *Observable[T] {
observable := &Observable{ observable := &Observable[T]{
iterable: any, iterable: iter,
listener: map[Subscription]*Subscriber{}, listener: map[Subscription[T]]*Subscriber[T]{},
} }
go observable.process() go observable.process()
return observable return observable

View File

@ -9,8 +9,8 @@ import (
"go.uber.org/atomic" "go.uber.org/atomic"
) )
func iterator(item []any) chan any { func iterator[T any](item []T) chan T {
ch := make(chan any) ch := make(chan T)
go func() { go func() {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
for _, elm := range item { for _, elm := range item {
@ -22,8 +22,8 @@ func iterator(item []any) chan any {
} }
func TestObservable(t *testing.T) { func TestObservable(t *testing.T) {
iter := iterator([]any{1, 2, 3, 4, 5}) iter := iterator[int]([]int{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable[int](iter)
data, err := src.Subscribe() data, err := src.Subscribe()
assert.Nil(t, err) assert.Nil(t, err)
count := 0 count := 0
@ -34,15 +34,15 @@ func TestObservable(t *testing.T) {
} }
func TestObservable_MultiSubscribe(t *testing.T) { func TestObservable_MultiSubscribe(t *testing.T) {
iter := iterator([]any{1, 2, 3, 4, 5}) iter := iterator[int]([]int{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable[int](iter)
ch1, _ := src.Subscribe() ch1, _ := src.Subscribe()
ch2, _ := src.Subscribe() ch2, _ := src.Subscribe()
count := atomic.NewInt32(0) count := atomic.NewInt32(0)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
waitCh := func(ch <-chan any) { waitCh := func(ch <-chan int) {
for range ch { for range ch {
count.Inc() count.Inc()
} }
@ -55,8 +55,8 @@ func TestObservable_MultiSubscribe(t *testing.T) {
} }
func TestObservable_UnSubscribe(t *testing.T) { func TestObservable_UnSubscribe(t *testing.T) {
iter := iterator([]any{1, 2, 3, 4, 5}) iter := iterator[int]([]int{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable[int](iter)
data, err := src.Subscribe() data, err := src.Subscribe()
assert.Nil(t, err) assert.Nil(t, err)
src.UnSubscribe(data) src.UnSubscribe(data)
@ -65,8 +65,8 @@ func TestObservable_UnSubscribe(t *testing.T) {
} }
func TestObservable_SubscribeClosedSource(t *testing.T) { func TestObservable_SubscribeClosedSource(t *testing.T) {
iter := iterator([]any{1}) iter := iterator[int]([]int{1})
src := NewObservable(iter) src := NewObservable[int](iter)
data, _ := src.Subscribe() data, _ := src.Subscribe()
<-data <-data
@ -75,18 +75,18 @@ func TestObservable_SubscribeClosedSource(t *testing.T) {
} }
func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) {
sub := Subscription(make(chan any)) sub := Subscription[int](make(chan int))
iter := iterator([]any{1}) iter := iterator[int]([]int{1})
src := NewObservable(iter) src := NewObservable[int](iter)
src.UnSubscribe(sub) src.UnSubscribe(sub)
} }
func TestObservable_SubscribeGoroutineLeak(t *testing.T) { func TestObservable_SubscribeGoroutineLeak(t *testing.T) {
iter := iterator([]any{1, 2, 3, 4, 5}) iter := iterator[int]([]int{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable[int](iter)
max := 100 max := 100
var list []Subscription var list []Subscription[int]
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
ch, _ := src.Subscribe() ch, _ := src.Subscribe()
list = append(list, ch) list = append(list, ch)
@ -94,7 +94,7 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(max) wg.Add(max)
waitCh := func(ch <-chan any) { waitCh := func(ch <-chan int) {
for range ch { for range ch {
} }
wg.Done() wg.Done()
@ -115,11 +115,11 @@ func TestObservable_SubscribeGoroutineLeak(t *testing.T) {
} }
func Benchmark_Observable_1000(b *testing.B) { func Benchmark_Observable_1000(b *testing.B) {
ch := make(chan any) ch := make(chan int)
o := NewObservable(ch) o := NewObservable[int](ch)
num := 1000 num := 1000
subs := []Subscription{} subs := []Subscription[int]{}
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
sub, _ := o.Subscribe() sub, _ := o.Subscribe()
subs = append(subs, sub) subs = append(subs, sub)
@ -130,7 +130,7 @@ func Benchmark_Observable_1000(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for _, sub := range subs { for _, sub := range subs {
go func(s Subscription) { go func(s Subscription[int]) {
for range s { for range s {
} }
wg.Done() wg.Done()

View File

@ -4,30 +4,30 @@ import (
"sync" "sync"
) )
type Subscription <-chan any type Subscription[T any] <-chan T
type Subscriber struct { type Subscriber[T any] struct {
buffer chan any buffer chan T
once sync.Once once sync.Once
} }
func (s *Subscriber) Emit(item any) { func (s *Subscriber[T]) Emit(item T) {
s.buffer <- item s.buffer <- item
} }
func (s *Subscriber) Out() Subscription { func (s *Subscriber[T]) Out() Subscription[T] {
return s.buffer return s.buffer
} }
func (s *Subscriber) Close() { func (s *Subscriber[T]) Close() {
s.once.Do(func() { s.once.Do(func() {
close(s.buffer) close(s.buffer)
}) })
} }
func newSubscriber() *Subscriber { func newSubscriber[T any]() *Subscriber[T] {
sub := &Subscriber{ sub := &Subscriber[T]{
buffer: make(chan any, 200), buffer: make(chan T, 200),
} }
return sub return sub
} }

View File

@ -9,7 +9,7 @@ import (
// Picker provides synchronization, and Context cancelation // Picker provides synchronization, and Context cancelation
// for groups of goroutines working on subtasks of a common task. // for groups of goroutines working on subtasks of a common task.
// Inspired by errGroup // Inspired by errGroup
type Picker struct { type Picker[T any] struct {
ctx context.Context ctx context.Context
cancel func() cancel func()
@ -17,12 +17,12 @@ type Picker struct {
once sync.Once once sync.Once
errOnce sync.Once errOnce sync.Once
result any result T
err error err error
} }
func newPicker(ctx context.Context, cancel func()) *Picker { func newPicker[T any](ctx context.Context, cancel func()) *Picker[T] {
return &Picker{ return &Picker[T]{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
@ -30,20 +30,20 @@ func newPicker(ctx context.Context, cancel func()) *Picker {
// WithContext returns a new Picker and an associated Context derived from ctx. // WithContext returns a new Picker and an associated Context derived from ctx.
// and cancel when first element return. // and cancel when first element return.
func WithContext(ctx context.Context) (*Picker, context.Context) { func WithContext[T any](ctx context.Context) (*Picker[T], context.Context) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return newPicker(ctx, cancel), ctx return newPicker[T](ctx, cancel), ctx
} }
// WithTimeout returns a new Picker and an associated Context derived from ctx with timeout. // WithTimeout returns a new Picker and an associated Context derived from ctx with timeout.
func WithTimeout(ctx context.Context, timeout time.Duration) (*Picker, context.Context) { func WithTimeout[T any](ctx context.Context, timeout time.Duration) (*Picker[T], context.Context) {
ctx, cancel := context.WithTimeout(ctx, timeout) ctx, cancel := context.WithTimeout(ctx, timeout)
return newPicker(ctx, cancel), ctx return newPicker[T](ctx, cancel), ctx
} }
// Wait blocks until all function calls from the Go method have returned, // Wait blocks until all function calls from the Go method have returned,
// then returns the first nil error result (if any) from them. // then returns the first nil error result (if any) from them.
func (p *Picker) Wait() any { func (p *Picker[T]) Wait() T {
p.wg.Wait() p.wg.Wait()
if p.cancel != nil { if p.cancel != nil {
p.cancel() p.cancel()
@ -52,13 +52,13 @@ func (p *Picker) Wait() any {
} }
// Error return the first error (if all success return nil) // Error return the first error (if all success return nil)
func (p *Picker) Error() error { func (p *Picker[T]) Error() error {
return p.err return p.err
} }
// Go calls the given function in a new goroutine. // Go calls the given function in a new goroutine.
// The first call to return a nil error cancels the group; its result will be returned by Wait. // The first call to return a nil error cancels the group; its result will be returned by Wait.
func (p *Picker) Go(f func() (any, error)) { func (p *Picker[T]) Go(f func() (T, error)) {
p.wg.Add(1) p.wg.Add(1)
go func() { go func() {

View File

@ -8,33 +8,38 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func sleepAndSend(ctx context.Context, delay int, input any) func() (any, error) { func sleepAndSend[T any](ctx context.Context, delay int, input T) func() (T, error) {
return func() (any, error) { return func() (T, error) {
timer := time.NewTimer(time.Millisecond * time.Duration(delay)) timer := time.NewTimer(time.Millisecond * time.Duration(delay))
select { select {
case <-timer.C: case <-timer.C:
return input, nil return input, nil
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return getZero[T](), ctx.Err()
} }
} }
} }
func TestPicker_Basic(t *testing.T) { func TestPicker_Basic(t *testing.T) {
picker, ctx := WithContext(context.Background()) picker, ctx := WithContext[int](context.Background())
picker.Go(sleepAndSend(ctx, 30, 2)) picker.Go(sleepAndSend(ctx, 30, 2))
picker.Go(sleepAndSend(ctx, 20, 1)) picker.Go(sleepAndSend(ctx, 20, 1))
number := picker.Wait() number := picker.Wait()
assert.NotNil(t, number) assert.NotNil(t, number)
assert.Equal(t, number.(int), 1) assert.Equal(t, number, 1)
} }
func TestPicker_Timeout(t *testing.T) { func TestPicker_Timeout(t *testing.T) {
picker, ctx := WithTimeout(context.Background(), time.Millisecond*5) picker, ctx := WithTimeout[int](context.Background(), time.Millisecond*5)
picker.Go(sleepAndSend(ctx, 20, 1)) picker.Go(sleepAndSend(ctx, 20, 1))
number := picker.Wait() number := picker.Wait()
assert.Nil(t, number) assert.Equal(t, number, getZero[int]())
assert.NotNil(t, picker.Error()) assert.NotNil(t, picker.Error())
} }
func getZero[T any]() T {
var result T
return result
}

View File

@ -5,28 +5,28 @@ import (
"time" "time"
) )
type call struct { type call[T any] struct {
wg sync.WaitGroup wg sync.WaitGroup
val any val T
err error err error
} }
type Single struct { type Single[T any] struct {
mux sync.Mutex mux sync.Mutex
last time.Time last time.Time
wait time.Duration wait time.Duration
call *call call *call[T]
result *Result result *Result[T]
} }
type Result struct { type Result[T any] struct {
Val any Val T
Err error Err error
} }
// Do single.Do likes sync.singleFlight // Do single.Do likes sync.singleFlight
//lint:ignore ST1008 it likes sync.singleFlight //lint:ignore ST1008 it likes sync.singleFlight
func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) { func (s *Single[T]) Do(fn func() (T, error)) (v T, err error, shared bool) {
s.mux.Lock() s.mux.Lock()
now := time.Now() now := time.Now()
if now.Before(s.last.Add(s.wait)) { if now.Before(s.last.Add(s.wait)) {
@ -34,31 +34,31 @@ func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) {
return s.result.Val, s.result.Err, true return s.result.Val, s.result.Err, true
} }
if call := s.call; call != nil { if callM := s.call; callM != nil {
s.mux.Unlock() s.mux.Unlock()
call.wg.Wait() callM.wg.Wait()
return call.val, call.err, true return callM.val, callM.err, true
} }
call := &call{} callM := &call[T]{}
call.wg.Add(1) callM.wg.Add(1)
s.call = call s.call = callM
s.mux.Unlock() s.mux.Unlock()
call.val, call.err = fn() callM.val, callM.err = fn()
call.wg.Done() callM.wg.Done()
s.mux.Lock() s.mux.Lock()
s.call = nil s.call = nil
s.result = &Result{call.val, call.err} s.result = &Result[T]{callM.val, callM.err}
s.last = now s.last = now
s.mux.Unlock() s.mux.Unlock()
return call.val, call.err, false return callM.val, callM.err, false
} }
func (s *Single) Reset() { func (s *Single[T]) Reset() {
s.last = time.Time{} s.last = time.Time{}
} }
func NewSingle(wait time.Duration) *Single { func NewSingle[T any](wait time.Duration) *Single[T] {
return &Single{wait: wait} return &Single[T]{wait: wait}
} }

View File

@ -10,13 +10,13 @@ import (
) )
func TestBasic(t *testing.T) { func TestBasic(t *testing.T) {
single := NewSingle(time.Millisecond * 30) single := NewSingle[int](time.Millisecond * 30)
foo := 0 foo := 0
shardCount := atomic.NewInt32(0) shardCount := atomic.NewInt32(0)
call := func() (any, error) { call := func() (int, error) {
foo++ foo++
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 5)
return nil, nil return 0, nil
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@ -38,32 +38,32 @@ func TestBasic(t *testing.T) {
} }
func TestTimer(t *testing.T) { func TestTimer(t *testing.T) {
single := NewSingle(time.Millisecond * 30) single := NewSingle[int](time.Millisecond * 30)
foo := 0 foo := 0
call := func() (any, error) { callM := func() (int, error) {
foo++ foo++
return nil, nil return 0, nil
} }
single.Do(call) _, _, _ = single.Do(callM)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
_, _, shard := single.Do(call) _, _, shard := single.Do(callM)
assert.Equal(t, 1, foo) assert.Equal(t, 1, foo)
assert.True(t, shard) assert.True(t, shard)
} }
func TestReset(t *testing.T) { func TestReset(t *testing.T) {
single := NewSingle(time.Millisecond * 30) single := NewSingle[int](time.Millisecond * 30)
foo := 0 foo := 0
call := func() (any, error) { callM := func() (int, error) {
foo++ foo++
return nil, nil return 0, nil
} }
single.Do(call) _, _, _ = single.Do(callM)
single.Reset() single.Reset()
single.Do(call) _, _, _ = single.Do(callM)
assert.Equal(t, 2, foo) assert.Equal(t, 2, foo)
} }

View File

@ -1,7 +1,7 @@
package strmatcher package strmatcher
import ( import (
"container/list" "github.com/Dreamacro/clash/common/generics/list"
) )
const validCharCount = 53 const validCharCount = 53
@ -190,7 +190,7 @@ func (ac *ACAutomaton) Add(domain string, t Type) {
} }
func (ac *ACAutomaton) Build() { func (ac *ACAutomaton) Build() {
queue := list.New() queue := list.New[Edge]()
for i := 0; i < validCharCount; i++ { for i := 0; i < validCharCount; i++ {
if ac.trie[0][i].nextNode != 0 { if ac.trie[0][i].nextNode != 0 {
queue.PushBack(ac.trie[0][i]) queue.PushBack(ac.trie[0][i])
@ -201,7 +201,7 @@ func (ac *ACAutomaton) Build() {
if front == nil { if front == nil {
break break
} else { } else {
node := front.Value.(Edge).nextNode node := front.Value.nextNode
queue.Remove(front) queue.Remove(front)
for i := 0; i < validCharCount; i++ { for i := 0; i < validCharCount; i++ {
if ac.trie[node][i].nextNode != 0 { if ac.trie[node][i].nextNode != 0 {

View File

@ -21,10 +21,10 @@ var (
ErrAddrNotFound = errors.New("addr not found") ErrAddrNotFound = errors.New("addr not found")
) )
var interfaces = singledo.NewSingle(time.Second * 20) var interfaces = singledo.NewSingle[map[string]*Interface](time.Second * 20)
func ResolveInterface(name string) (*Interface, error) { func ResolveInterface(name string) (*Interface, error) {
value, err, _ := interfaces.Do(func() (any, error) { value, err, _ := interfaces.Do(func() (map[string]*Interface, error) {
ifaces, err := net.Interfaces() ifaces, err := net.Interfaces()
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,7 +66,7 @@ func ResolveInterface(name string) (*Interface, error) {
return nil, err return nil, err
} }
ifaces := value.(map[string]*Interface) ifaces := value
iface, ok := ifaces[name] iface, ok := ifaces[name]
if !ok { if !ok {
return nil, ErrIfaceNotFound return nil, ErrIfaceNotFound

View File

@ -6,55 +6,55 @@ import (
"time" "time"
) )
type Factory = func(context.Context) (any, error) type Factory[T any] func(context.Context) (T, error)
type entry struct { type entry[T any] struct {
elm any elm T
time time.Time time time.Time
} }
type Option func(*pool) type Option[T any] func(*pool[T])
// WithEvict set the evict callback // WithEvict set the evict callback
func WithEvict(cb func(any)) Option { func WithEvict[T any](cb func(T)) Option[T] {
return func(p *pool) { return func(p *pool[T]) {
p.evict = cb p.evict = cb
} }
} }
// WithAge defined element max age (millisecond) // WithAge defined element max age (millisecond)
func WithAge(maxAge int64) Option { func WithAge[T any](maxAge int64) Option[T] {
return func(p *pool) { return func(p *pool[T]) {
p.maxAge = maxAge p.maxAge = maxAge
} }
} }
// WithSize defined max size of Pool // WithSize defined max size of Pool
func WithSize(maxSize int) Option { func WithSize[T any](maxSize int) Option[T] {
return func(p *pool) { return func(p *pool[T]) {
p.ch = make(chan any, maxSize) p.ch = make(chan *entry[T], maxSize)
} }
} }
// Pool is for GC, see New for detail // Pool is for GC, see New for detail
type Pool struct { type Pool[T any] struct {
*pool *pool[T]
} }
type pool struct { type pool[T any] struct {
ch chan any ch chan *entry[T]
factory Factory factory Factory[T]
evict func(any) evict func(T)
maxAge int64 maxAge int64
} }
func (p *pool) GetContext(ctx context.Context) (any, error) { func (p *pool[T]) GetContext(ctx context.Context) (T, error) {
now := time.Now() now := time.Now()
for { for {
select { select {
case item := <-p.ch: case item := <-p.ch:
elm := item.(*entry) elm := item
if p.maxAge != 0 && now.Sub(item.(*entry).time).Milliseconds() > p.maxAge { if p.maxAge != 0 && now.Sub(item.time).Milliseconds() > p.maxAge {
if p.evict != nil { if p.evict != nil {
p.evict(elm.elm) p.evict(elm.elm)
} }
@ -68,12 +68,12 @@ func (p *pool) GetContext(ctx context.Context) (any, error) {
} }
} }
func (p *pool) Get() (any, error) { func (p *pool[T]) Get() (T, error) {
return p.GetContext(context.Background()) return p.GetContext(context.Background())
} }
func (p *pool) Put(item any) { func (p *pool[T]) Put(item T) {
e := &entry{ e := &entry[T]{
elm: item, elm: item,
time: time.Now(), time: time.Now(),
} }
@ -90,17 +90,17 @@ func (p *pool) Put(item any) {
} }
} }
func recycle(p *Pool) { func recycle[T any](p *Pool[T]) {
for item := range p.pool.ch { for item := range p.pool.ch {
if p.pool.evict != nil { if p.pool.evict != nil {
p.pool.evict(item.(*entry).elm) p.pool.evict(item.elm)
} }
} }
} }
func New(factory Factory, options ...Option) *Pool { func New[T any](factory Factory[T], options ...Option[T]) *Pool[T] {
p := &pool{ p := &pool[T]{
ch: make(chan any, 10), ch: make(chan *entry[T], 10),
factory: factory, factory: factory,
} }
@ -108,7 +108,7 @@ func New(factory Factory, options ...Option) *Pool {
option(p) option(p)
} }
P := &Pool{p} P := &Pool[T]{p}
runtime.SetFinalizer(P, recycle) runtime.SetFinalizer(P, recycle[T])
return P return P
} }

View File

@ -8,9 +8,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func lg() Factory { func lg() Factory[int] {
initial := -1 initial := -1
return func(context.Context) (any, error) { return func(context.Context) (int, error) {
initial++ initial++
return initial, nil return initial, nil
} }
@ -18,23 +18,23 @@ func lg() Factory {
func TestPool_Basic(t *testing.T) { func TestPool_Basic(t *testing.T) {
g := lg() g := lg()
pool := New(g) pool := New[int](g)
elm, _ := pool.Get() elm, _ := pool.Get()
assert.Equal(t, 0, elm.(int)) assert.Equal(t, 0, elm)
pool.Put(elm) pool.Put(elm)
elm, _ = pool.Get() elm, _ = pool.Get()
assert.Equal(t, 0, elm.(int)) assert.Equal(t, 0, elm)
elm, _ = pool.Get() elm, _ = pool.Get()
assert.Equal(t, 1, elm.(int)) assert.Equal(t, 1, elm)
} }
func TestPool_MaxSize(t *testing.T) { func TestPool_MaxSize(t *testing.T) {
g := lg() g := lg()
size := 5 size := 5
pool := New(g, WithSize(size)) pool := New[int](g, WithSize[int](size))
var items []any var items []int
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
item, _ := pool.Get() item, _ := pool.Get()
@ -42,7 +42,7 @@ func TestPool_MaxSize(t *testing.T) {
} }
extra, _ := pool.Get() extra, _ := pool.Get()
assert.Equal(t, size, extra.(int)) assert.Equal(t, size, extra)
for _, item := range items { for _, item := range items {
pool.Put(item) pool.Put(item)
@ -52,22 +52,22 @@ func TestPool_MaxSize(t *testing.T) {
for _, item := range items { for _, item := range items {
elm, _ := pool.Get() elm, _ := pool.Get()
assert.Equal(t, item.(int), elm.(int)) assert.Equal(t, item, elm)
} }
} }
func TestPool_MaxAge(t *testing.T) { func TestPool_MaxAge(t *testing.T) {
g := lg() g := lg()
pool := New(g, WithAge(20)) pool := New[int](g, WithAge[int](20))
elm, _ := pool.Get() elm, _ := pool.Get()
pool.Put(elm) pool.Put(elm)
elm, _ = pool.Get() elm, _ = pool.Get()
assert.Equal(t, 0, elm.(int)) assert.Equal(t, 0, elm)
pool.Put(elm) pool.Put(elm)
time.Sleep(time.Millisecond * 22) time.Sleep(time.Millisecond * 22)
elm, _ = pool.Get() elm, _ = pool.Get()
assert.Equal(t, 1, elm.(int)) assert.Equal(t, 1, elm)
} }

View File

@ -152,10 +152,10 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M
} }
func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout) fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout)
for _, client := range clients { for _, client := range clients {
r := client r := client
fast.Go(func() (any, error) { fast.Go(func() (*D.Msg, error) {
m, err := r.ExchangeContext(ctx, m) m, err := r.ExchangeContext(ctx, m)
if err != nil { if err != nil {
return nil, err return nil, err
@ -175,7 +175,7 @@ func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.
return nil, err return nil, err
} }
msg = elm.(*D.Msg) msg = elm
return return
} }

6
go.mod
View File

@ -19,16 +19,16 @@ require (
go.uber.org/atomic v1.9.0 go.uber.org/atomic v1.9.0
go.uber.org/automaxprocs v1.5.1 go.uber.org/automaxprocs v1.5.1
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 golang.org/x/net v0.0.0-20220421235706-1d1ef9303861
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad golang.org/x/sys v0.0.0-20220422013727-9388b58f7150
golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab
golang.org/x/time v0.0.0-20220411224347-583f2d630306 golang.org/x/time v0.0.0-20220411224347-583f2d630306
golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6 golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6
golang.zx2c4.com/wireguard/windows v0.5.4-0.20220317000008-6432784c2469 golang.zx2c4.com/wireguard/windows v0.5.4-0.20220317000008-6432784c2469
google.golang.org/protobuf v1.28.0 google.golang.org/protobuf v1.28.0
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4
) )
require ( require (

12
go.sum
View File

@ -97,8 +97,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 h1:6mzvA99KwZxbOrxww4EvWVQUnN1+xEu9tafK5ZxkYeA= golang.org/x/net v0.0.0-20220421235706-1d1ef9303861 h1:yssD99+7tqHWO5Gwh81phT+67hg+KttniBr6UnEXOY8=
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220421235706-1d1ef9303861/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -122,8 +122,8 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@ -159,5 +159,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b h1:zBJp2eKSoNIV6+9LO3bRhlnuK280Oyrwc6OeFIN6VzU= gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 h1:CSkd548jw5hmVwdJ+JuUhMtRV56oQBER7sbkIOePP2Y=
gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI=

View File

@ -2,11 +2,9 @@ package executor
import ( import (
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"runtime" "runtime"
"strconv"
"sync" "sync"
"github.com/Dreamacro/clash/adapter" "github.com/Dreamacro/clash/adapter"
@ -317,13 +315,7 @@ func updateIPTables(cfg *config.Config) {
return return
} }
_, dnsPortStr, err := net.SplitHostPort(dnsCfg.Listen) dnsPort, err := netip.ParseAddrPort(dnsCfg.Listen)
if err != nil {
err = fmt.Errorf("DNS server must be enable")
return
}
dnsPort, err := strconv.ParseUint(dnsPortStr, 10, 16)
if err != nil { if err != nil {
err = fmt.Errorf("DNS server must be enable") err = fmt.Errorf("DNS server must be enable")
return return
@ -337,7 +329,7 @@ func updateIPTables(cfg *config.Config) {
dialer.DefaultRoutingMark.Store(2158) dialer.DefaultRoutingMark.Store(2158)
} }
err = tproxy.SetTProxyIPTables(inboundInterface, uint16(tProxyPort), uint16(dnsPort)) err = tproxy.SetTProxyIPTables(inboundInterface, uint16(tProxyPort), dnsPort.Port())
if err != nil { if err != nil {
return return
} }

View File

@ -51,14 +51,14 @@ func Start(addr string, secret string) {
r := chi.NewRouter() r := chi.NewRouter()
cors := cors.New(cors.Options{ corsM := cors.New(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"}, AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"},
AllowedHeaders: []string{"Content-Type", "Authorization"}, AllowedHeaders: []string{"Content-Type", "Authorization"},
MaxAge: 300, MaxAge: 300,
}) })
r.Use(cors.Handler) r.Use(corsM.Handler)
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(authentication) r.Use(authentication)
@ -209,24 +209,35 @@ func getLogs(w http.ResponseWriter, r *http.Request) {
render.Status(r, http.StatusOK) render.Status(r, http.StatusOK)
} }
ch := make(chan log.Event, 1024)
sub := log.Subscribe() sub := log.Subscribe()
defer log.UnSubscribe(sub) defer log.UnSubscribe(sub)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
var err error
for elm := range sub { go func() {
buf.Reset() for elm := range sub {
log := elm.(*log.Event) select {
if log.LogLevel < level { case ch <- elm:
default:
}
}
close(ch)
}()
for logM := range ch {
if logM.LogLevel < level {
continue continue
} }
buf.Reset()
if err := json.NewEncoder(buf).Encode(Log{ if err := json.NewEncoder(buf).Encode(Log{
Type: log.Type(), Type: logM.Type(),
Payload: log.Payload, Payload: logM.Payload,
}); err != nil { }); err != nil {
break break
} }
var err error
if wsConn == nil { if wsConn == nil {
_, err = w.Write(buf.Bytes()) _, err = w.Write(buf.Bytes())
w.(http.Flusher).Flush() w.(http.Flusher).Flush()

View File

@ -19,12 +19,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
client := newClient(c.RemoteAddr(), in) client := newClient(c.RemoteAddr(), in)
defer client.CloseIdleConnections() defer client.CloseIdleConnections()
var conn *N.BufferedConn conn := N.NewBufferedConn(c)
if bufConn, ok := c.(*N.BufferedConn); ok {
conn = bufConn
} else {
conn = N.NewBufferedConn(c)
}
keepAlive := true keepAlive := true
trusted := cache == nil // disable authenticate if cache is nil trusted := cache == nil // disable authenticate if cache is nil
@ -66,15 +61,23 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
request.RequestURI = "" request.RequestURI = ""
RemoveHopByHopHeaders(request.Header) if isUpgradeRequest(request) {
RemoveExtraHTTPHostPort(request) if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil {
return // hijack connection
}
}
if request.URL.Scheme == "" || request.URL.Host == "" { if resp == nil {
resp = responseWith(request, http.StatusBadRequest) RemoveHopByHopHeaders(request.Header)
} else { RemoveExtraHTTPHostPort(request)
resp, err = client.Do(request)
if err != nil { if request.URL.Scheme == "" || request.URL.Host == "" {
resp = responseWith(request, http.StatusBadGateway) resp = responseWith(request, http.StatusBadRequest)
} else {
resp, err = client.Do(request)
if err != nil {
resp = responseWith(request, http.StatusBadGateway)
}
} }
} }
@ -95,7 +98,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
} }
} }
conn.Close() _ = conn.Close()
} }
func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response {

96
listener/http/upgrade.go Normal file
View File

@ -0,0 +1,96 @@
package http
import (
"context"
"crypto/tls"
"net"
"net/http"
"strings"
"time"
"github.com/Dreamacro/clash/adapter/inbound"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)
func isUpgradeRequest(req *http.Request) bool {
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
}
func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
removeProxyHeaders(request.Header)
RemoveExtraHTTPHostPort(request)
address := request.Host
if _, _, err := net.SplitHostPort(address); err != nil {
port := "80"
if request.TLS != nil {
port = "443"
}
address = net.JoinHostPort(address, port)
}
dstAddr := socks5.ParseAddr(address)
if dstAddr == nil {
return
}
left, right := net.Pipe()
in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right)
var remoteServer *N.BufferedConn
if request.TLS != nil {
tlsConn := tls.Client(left, &tls.Config{
ServerName: request.URL.Hostname(),
})
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
defer cancel()
if tlsConn.HandshakeContext(ctx) != nil {
_ = localConn.Close()
_ = left.Close()
return
}
remoteServer = N.NewBufferedConn(tlsConn)
} else {
remoteServer = N.NewBufferedConn(left)
}
defer func() {
_ = remoteServer.Close()
}()
err := request.Write(remoteServer)
if err != nil {
_ = localConn.Close()
return
}
resp, err = http.ReadResponse(remoteServer.Reader(), request)
if err != nil {
_ = localConn.Close()
return
}
if resp.StatusCode == http.StatusSwitchingProtocols {
removeProxyHeaders(resp.Header)
err = localConn.SetReadDeadline(time.Time{}) // set to not time out
if err != nil {
return
}
err = resp.Write(localConn)
if err != nil {
return
}
N.Relay(remoteServer, localConn) // blocking here
_ = localConn.Close()
resp = nil
}
return
}

View File

@ -8,15 +8,21 @@ import (
"strings" "strings"
) )
// removeProxyHeaders remove Proxy-* headers
func removeProxyHeaders(header http.Header) {
header.Del("Proxy-Connection")
header.Del("Proxy-Authenticate")
header.Del("Proxy-Authorization")
}
// RemoveHopByHopHeaders remove hop-by-hop header // RemoveHopByHopHeaders remove hop-by-hop header
func RemoveHopByHopHeaders(header http.Header) { func RemoveHopByHopHeaders(header http.Header) {
// Strip hop-by-hop header based on RFC: // Strip hop-by-hop header based on RFC:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do // https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
header.Del("Proxy-Connection") removeProxyHeaders(header)
header.Del("Proxy-Authenticate")
header.Del("Proxy-Authorization")
header.Del("TE") header.Del("TE")
header.Del("Trailers") header.Del("Trailers")
header.Del("Transfer-Encoding") header.Del("Transfer-Encoding")

View File

@ -397,13 +397,12 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) {
certOption, err = cert.NewConfig( certOption, err = cert.NewConfig(
x509c, x509c,
privateKey, privateKey,
cert.NewAutoGCCertsStorage(),
) )
if err != nil { if err != nil {
return return
} }
certOption.SetValidity(time.Hour * 24 * 90) certOption.SetValidity(time.Hour * 24 * 365 * 2) // 2 years
certOption.SetOrganization("Clash ManInTheMiddle Proxy Services") certOption.SetOrganization("Clash ManInTheMiddle Proxy Services")
opt := &mitm.Option{ opt := &mitm.Option{

View File

@ -18,9 +18,11 @@ func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http
Transport: &http.Transport{ Transport: &http.Transport{
// excepted HTTP/2 // excepted HTTP/2
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
// from http.DefaultTransport // only needed 1 connection
MaxIdleConns: 100, MaxIdleConns: 1,
IdleConnTimeout: 90 * time.Second, MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
IdleConnTimeout: 60 * time.Second,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{

View File

@ -44,13 +44,13 @@ startOver:
readLoop: readLoop:
for { for {
// use SetReadDeadline instead of Proxy-Connection keep-alive // use SetReadDeadline instead of Proxy-Connection keep-alive
if err := conn.SetReadDeadline(time.Now().Add(95 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil {
break readLoop break
} }
request, err := H.ReadRequest(conn.Reader()) request, err := H.ReadRequest(conn.Reader())
if err != nil { if err != nil {
break readLoop break
} }
var response *http.Response var response *http.Response
@ -71,7 +71,7 @@ readLoop:
// Manual writing to support CONNECT for http 1.0 (workaround for uplay client) // Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break // close connection
} }
if strings.HasSuffix(session.request.URL.Host, ":80") { if strings.HasSuffix(session.request.URL.Host, ":80") {
@ -81,18 +81,18 @@ readLoop:
b, err := conn.Peek(1) b, err := conn.Peek(1)
if err != nil { if err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break // close connection
} }
// TLS handshake. // TLS handshake.
if b[0] == 0x16 { if b[0] == 0x16 {
tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Hostname()))
// Handshake with the local client // Handshake with the local client
if err = tlsConn.Handshake(); err != nil { if err = tlsConn.Handshake(); err != nil {
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
_ = writeResponse(session, false) _ = writeResponse(session, false)
break readLoop // close connection break // close connection
} }
c = tlsConn c = tlsConn
@ -105,20 +105,27 @@ readLoop:
prepareRequest(c, session.request) prepareRequest(c, session.request)
H.RemoveHopByHopHeaders(session.request.Header)
H.RemoveExtraHTTPHostPort(session.request)
// hijack api // hijack api
if session.request.URL.Hostname() == opt.ApiHost { if session.request.URL.Hostname() == opt.ApiHost {
if err = handleApiRequest(session, opt); err != nil { if err = handleApiRequest(session, opt); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop
} }
return break
} }
// forward websocket
if isWebsocketRequest(request) {
session.request.RequestURI = ""
if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil {
return // hijack connection
}
}
H.RemoveHopByHopHeaders(session.request.Header)
H.RemoveExtraHTTPHostPort(session.request)
// hijack custom request and write back custom response if necessary // hijack custom request and write back custom response if necessary
if opt.Handler != nil { if opt.Handler != nil && session.response == nil {
newReq, newRes := opt.Handler.HandleRequest(session) newReq, newRes := opt.Handler.HandleRequest(session)
if newReq != nil { if newReq != nil {
session.request = newReq session.request = newReq
@ -128,28 +135,30 @@ readLoop:
if err = writeResponse(session, false); err != nil { if err = writeResponse(session, false); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop break
} }
return continue
} }
} }
session.request.RequestURI = "" if session.response == nil {
session.request.RequestURI = ""
if session.request.URL.Host == "" { if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(ErrInvalidURL) session.response = session.NewErrorResponse(ErrInvalidURL)
} else { } else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
// send the request to remote server // send the request to remote server
session.response, err = client.Do(session.request) session.response, err = client.Do(session.request)
if err != nil { if err != nil {
handleError(opt, session, err) handleError(opt, session, err)
session.response = session.NewErrorResponse(err) session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err))
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
_ = writeResponse(session, false) _ = writeResponse(session, false)
break readLoop break
}
} }
} }
} }
@ -157,7 +166,7 @@ readLoop:
if err = writeResponseWithHandler(session, opt); err != nil { if err = writeResponseWithHandler(session, opt); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break // close connection
} }
} }
@ -167,13 +176,7 @@ readLoop:
func writeResponseWithHandler(session *Session, opt *Option) error { func writeResponseWithHandler(session *Session, opt *Option) error {
if opt.Handler != nil { if opt.Handler != nil {
res := opt.Handler.HandleResponse(session) res := opt.Handler.HandleResponse(session)
if res != nil { if res != nil {
body := res.Body
defer func(body io.ReadCloser) {
_ = body.Close()
}(body)
session.response = res session.response = res
} }
} }
@ -186,7 +189,7 @@ func writeResponse(session *Session, keepAlive bool) error {
if keepAlive { if keepAlive {
session.response.Header.Set("Connection", "keep-alive") session.response.Header.Set("Connection", "keep-alive")
session.response.Header.Set("Keep-Alive", "timeout=90") session.response.Header.Set("Keep-Alive", "timeout=60")
} }
return session.writeResponse() return session.writeResponse()
@ -201,10 +204,6 @@ func handleApiRequest(session *Session, opt *Option) error {
session.response = session.NewResponse(http.StatusOK, bytes.NewReader(b)) session.response = session.NewResponse(http.StatusOK, bytes.NewReader(b))
defer func(body io.ReadCloser) {
_ = body.Close()
}(session.response.Body)
session.response.Close = true session.response.Close = true
session.response.Header.Set("Content-Type", "application/x-x509-ca-cert") session.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
session.response.ContentLength = int64(len(b)) session.response.ContentLength = int64(len(b))
@ -230,11 +229,6 @@ func handleApiRequest(session *Session, opt *Option) error {
b = fmt.Sprintf(b, session.request.URL.Path) b = fmt.Sprintf(b, session.request.URL.Path)
session.response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b))) session.response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b)))
defer func(body io.ReadCloser) {
_ = body.Close()
}(session.response.Body)
session.response.Close = true session.response.Close = true
session.response.Header.Set("Content-Type", "text/html;charset=utf-8") session.response.Header.Set("Content-Type", "text/html;charset=utf-8")
session.response.ContentLength = int64(len(b)) session.response.ContentLength = int64(len(b))
@ -243,6 +237,12 @@ func handleApiRequest(session *Session, opt *Option) error {
} }
func handleError(opt *Option, session *Session, err error) { func handleError(opt *Option, session *Session, err error) {
if session.response != nil {
defer func() {
_, _ = io.Copy(io.Discard, session.response.Body)
_ = session.response.Body.Close()
}()
}
if opt.Handler != nil { if opt.Handler != nil {
opt.Handler.HandleError(session, err) opt.Handler.HandleError(session, err)
} }

View File

@ -43,6 +43,9 @@ func (s *Session) writeResponse() error {
if s.response == nil { if s.response == nil {
return ErrInvalidResponse return ErrInvalidResponse
} }
defer func(resp *http.Response) {
_ = resp.Body.Close()
}(s.response)
return s.response.Write(s.conn) return s.response.Write(s.conn)
} }

View File

@ -20,6 +20,10 @@ var (
ErrInvalidURL = errors.New("invalid URL") ErrInvalidURL = errors.New("invalid URL")
) )
func isWebsocketRequest(req *http.Request) bool {
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket"
}
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
if body == nil { if body == nil {
body = &bytes.Buffer{} body = &bytes.Buffer{}

View File

@ -27,10 +27,8 @@ func StartListener(device io.ReadWriteCloser, gateway, portal, broadcast netip.A
} }
func (t *StackListener) Close() error { func (t *StackListener) Close() error {
_ = t.tcp.Close()
_ = t.udp.Close() _ = t.udp.Close()
return t.tcp.Close()
return t.device.Close()
} }
func (t *StackListener) TCP() *nat.TCP { func (t *StackListener) TCP() *nat.TCP {

View File

@ -1,13 +1,14 @@
package nat package nat
import ( import (
"container/list"
"net/netip" "net/netip"
"github.com/Dreamacro/clash/common/generics/list"
) )
const ( const (
portBegin = 30000 portBegin = 30000
portLength = 4096 portLength = 10240
) )
var zeroTuple = tuple{} var zeroTuple = tuple{}
@ -23,9 +24,9 @@ type binding struct {
} }
type table struct { type table struct {
tuples map[tuple]*list.Element tuples map[tuple]*list.Element[*binding]
ports [portLength]*list.Element ports [portLength]*list.Element[*binding]
available *list.List available *list.List[*binding]
} }
func (t *table) tupleOf(port uint16) tuple { func (t *table) tupleOf(port uint16) tuple {
@ -38,7 +39,7 @@ func (t *table) tupleOf(port uint16) tuple {
t.available.MoveToFront(elm) t.available.MoveToFront(elm)
return elm.Value.(*binding).tuple return elm.Value.tuple
} }
func (t *table) portOf(tuple tuple) uint16 { func (t *table) portOf(tuple tuple) uint16 {
@ -49,12 +50,12 @@ func (t *table) portOf(tuple tuple) uint16 {
t.available.MoveToFront(elm) t.available.MoveToFront(elm)
return portBegin + elm.Value.(*binding).offset return portBegin + elm.Value.offset
} }
func (t *table) newConn(tuple tuple) uint16 { func (t *table) newConn(tuple tuple) uint16 {
elm := t.available.Back() elm := t.available.Back()
b := elm.Value.(*binding) b := elm.Value
delete(t.tuples, b.tuple) delete(t.tuples, b.tuple)
t.tuples[tuple] = elm t.tuples[tuple] = elm
@ -67,9 +68,9 @@ func (t *table) newConn(tuple tuple) uint16 {
func newTable() *table { func newTable() *table {
result := &table{ result := &table{
tuples: make(map[tuple]*list.Element, portLength), tuples: make(map[tuple]*list.Element[*binding], portLength),
ports: [portLength]*list.Element{}, ports: [portLength]*list.Element[*binding]{},
available: list.New(), available: list.New[*binding](),
} }
for idx := range result.ports { for idx := range result.ports {

View File

@ -7,6 +7,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
@ -28,6 +29,8 @@ type sysStack struct {
device device.Device device device.Device
closed bool closed bool
once sync.Once
wg sync.WaitGroup
} }
func (s *sysStack) Close() error { func (s *sysStack) Close() error {
@ -38,10 +41,12 @@ func (s *sysStack) Close() error {
}() }()
s.closed = true s.closed = true
if s.stack != nil {
return s.stack.Close() err := s.stack.Close()
}
return nil s.wg.Wait()
return err
} }
func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
@ -67,16 +72,10 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
_ = tcp.Close() _ = tcp.Close()
}(stack.TCP()) }(stack.TCP())
defer log.Debugln("TCP: closed")
for !ipStack.closed { for !ipStack.closed {
if err = stack.TCP().SetDeadline(time.Time{}); err != nil {
break
}
conn, err := stack.TCP().Accept() conn, err := stack.TCP().Accept()
if err != nil { if err != nil {
log.Debugln("Accept connection: %v", err) log.Debugln("[STACK] accept connection error: %v", err)
continue continue
} }
@ -146,6 +145,8 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
tcpIn <- context.NewConnContext(conn, metadata) tcpIn <- context.NewConnContext(conn, metadata)
} }
ipStack.wg.Done()
} }
udp := func() { udp := func() {
@ -153,14 +154,13 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
_ = udp.Close() _ = udp.Close()
}(stack.UDP()) }(stack.UDP())
defer log.Debugln("UDP: closed")
for !ipStack.closed { for !ipStack.closed {
buf := pool.Get(pool.UDPBufferSize) buf := pool.Get(pool.UDPBufferSize)
n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf) n, lRAddr, rRAddr, err := stack.UDP().ReadFrom(buf)
if err != nil { if err != nil {
return _ = pool.Put(buf)
break
} }
raw := buf[:n] raw := buf[:n]
@ -209,17 +209,23 @@ func New(device device.Device, dnsHijack []netip.AddrPort, tunAddress netip.Pref
default: default:
} }
} }
ipStack.wg.Done()
} }
go tcp() ipStack.once.Do(func() {
ipStack.wg.Add(1)
go tcp()
numUDPWorkers := 4 numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers { if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num numUDPWorkers = num
} }
for i := 0; i < numUDPWorkers; i++ { for i := 0; i < numUDPWorkers; i++ {
go udp() ipStack.wg.Add(1)
} go udp()
}
})
return ipStack, nil return ipStack, nil
} }

View File

@ -145,6 +145,7 @@ func setAtLatest(stackType C.TUNStack, devName string) {
case "darwin": case "darwin":
// _, _ = cmd.ExecCmd("sysctl -w net.inet.ip.forwarding=1") // _, _ = cmd.ExecCmd("sysctl -w net.inet.ip.forwarding=1")
// _, _ = cmd.ExecCmd("sysctl -w net.inet6.ip6.forwarding=1") // _, _ = cmd.ExecCmd("sysctl -w net.inet6.ip6.forwarding=1")
_, _ = cmd.ExecCmd("sudo launchctl limit maxfiles 10240 unlimited")
case "windows": case "windows":
_, _ = cmd.ExecCmd("ipconfig /renew") _, _ = cmd.ExecCmd("ipconfig /renew")
case "linux": case "linux":

View File

@ -10,8 +10,8 @@ import (
) )
var ( var (
logCh = make(chan any) logCh = make(chan Event)
source = observable.NewObservable(logCh) source = observable.NewObservable[Event](logCh)
level = INFO level = INFO
) )
@ -25,7 +25,7 @@ type Event struct {
Payload string Payload string
} }
func (e *Event) Type() string { func (e Event) Type() string {
return e.LogLevel.String() return e.LogLevel.String()
} }
@ -57,12 +57,12 @@ func Fatalln(format string, v ...any) {
log.Fatalf(format, v...) log.Fatalf(format, v...)
} }
func Subscribe() observable.Subscription { func Subscribe() observable.Subscription[Event] {
sub, _ := source.Subscribe() sub, _ := source.Subscribe()
return sub return sub
} }
func UnSubscribe(sub observable.Subscription) { func UnSubscribe(sub observable.Subscription[Event]) {
source.UnSubscribe(sub) source.UnSubscribe(sub)
} }
@ -74,7 +74,7 @@ func SetLevel(newLevel LogLevel) {
level = newLevel level = newLevel
} }
func print(data *Event) { func print(data Event) {
if data.LogLevel < level { if data.LogLevel < level {
return return
} }
@ -91,8 +91,8 @@ func print(data *Event) {
} }
} }
func newLog(logLevel LogLevel, format string, v ...any) *Event { func newLog(logLevel LogLevel, format string, v ...any) Event {
return &Event{ return Event{
LogLevel: logLevel, LogLevel: logLevel,
Payload: fmt.Sprintf(format, v...), Payload: fmt.Sprintf(format, v...),
} }

View File

@ -8,7 +8,7 @@ require (
github.com/docker/go-connections v0.4.0 github.com/docker/go-connections v0.4.0
github.com/miekg/dns v1.1.48 github.com/miekg/dns v1.1.48
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 golang.org/x/net v0.0.0-20220421235706-1d1ef9303861
) )
replace github.com/Dreamacro/clash => ../ replace github.com/Dreamacro/clash => ../
@ -42,7 +42,7 @@ require (
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect
golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab // indirect golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab // indirect
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
golang.org/x/tools v0.1.9 // indirect golang.org/x/tools v0.1.9 // indirect
@ -56,5 +56,5 @@ require (
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
gotest.tools/v3 v3.1.0 // indirect gotest.tools/v3 v3.1.0 // indirect
gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b // indirect gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 // indirect
) )

View File

@ -1012,8 +1012,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 h1:6mzvA99KwZxbOrxww4EvWVQUnN1+xEu9tafK5ZxkYeA= golang.org/x/net v0.0.0-20220421235706-1d1ef9303861 h1:yssD99+7tqHWO5Gwh81phT+67hg+KttniBr6UnEXOY8=
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220421235706-1d1ef9303861/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -1143,8 +1143,8 @@ golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16CMAGuqwO2lX1mTyyRRc=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
@ -1413,8 +1413,8 @@ gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk=
gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8=
gotest.tools/v3 v3.1.0 h1:rVV8Tcg/8jHUkPUorwjaMTtemIMVXfIPKiOqnhEhakk= gotest.tools/v3 v3.1.0 h1:rVV8Tcg/8jHUkPUorwjaMTtemIMVXfIPKiOqnhEhakk=
gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ= gotest.tools/v3 v3.1.0/go.mod h1:fHy7eyTmJFO5bQbUsEGQ1v4m2J3Jz9eWL54TP2/ZuYQ=
gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b h1:zBJp2eKSoNIV6+9LO3bRhlnuK280Oyrwc6OeFIN6VzU= gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4 h1:CSkd548jw5hmVwdJ+JuUhMtRV56oQBER7sbkIOePP2Y=
gvisor.dev/gvisor v0.0.0-20220419020849-1f2f4462d45b/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI= gvisor.dev/gvisor v0.0.0-20220422224113-2cca6b79d9f4/go.mod h1:tWwEcFvJavs154OdjFCw78axNrsDlz4Zh8jvPqwcpGI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -11,7 +11,7 @@ import (
) )
type Pool struct { type Pool struct {
pool *pool.Pool pool *pool.Pool[*Snell]
} }
func (p *Pool) Get() (net.Conn, error) { func (p *Pool) Get() (net.Conn, error) {
@ -24,12 +24,12 @@ func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
return nil, err return nil, err
} }
return &PoolConn{elm.(*Snell), p}, nil return &PoolConn{elm, p}, nil
} }
func (p *Pool) Put(conn net.Conn) { func (p *Pool) Put(conn *Snell) {
if err := HalfClose(conn); err != nil { if err := HalfClose(conn); err != nil {
conn.Close() _ = conn.Close()
return return
} }
@ -64,22 +64,22 @@ func (pc *PoolConn) Write(b []byte) (int, error) {
func (pc *PoolConn) Close() error { func (pc *PoolConn) Close() error {
// clash use SetReadDeadline to break bidirectional copy between client and server. // clash use SetReadDeadline to break bidirectional copy between client and server.
// reset it before reuse connection to avoid io timeout error. // reset it before reuse connection to avoid io timeout error.
pc.Snell.Conn.SetReadDeadline(time.Time{}) _ = pc.Snell.Conn.SetReadDeadline(time.Time{})
pc.pool.Put(pc.Snell) pc.pool.Put(pc.Snell)
return nil return nil
} }
func NewPool(factory func(context.Context) (*Snell, error)) *Pool { func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
p := pool.New( p := pool.New[*Snell](
func(ctx context.Context) (any, error) { func(ctx context.Context) (*Snell, error) {
return factory(ctx) return factory(ctx)
}, },
pool.WithAge(15000), pool.WithAge[*Snell](15000),
pool.WithSize(10), pool.WithSize[*Snell](10),
pool.WithEvict(func(item any) { pool.WithEvict[*Snell](func(item *Snell) {
item.(*Snell).Close() _ = item.Close()
}), }),
) )
return &Pool{p} return &Pool{pool: p}
} }

View File

@ -2,7 +2,6 @@ package tunnel
import ( import (
"errors" "errors"
"io"
"net" "net"
"time" "time"
@ -63,35 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n
} }
func handleSocket(ctx C.ConnContext, outbound net.Conn) { func handleSocket(ctx C.ConnContext, outbound net.Conn) {
relay(ctx.Conn(), outbound) N.Relay(ctx.Conn(), outbound)
}
// relay copies between left and right bidirectionally.
func relay(leftConn, rightConn net.Conn) {
ch := make(chan error)
tcpKeepAlive(leftConn)
tcpKeepAlive(rightConn)
go func() {
buf := pool.Get(pool.RelayBufferSize)
// Wrapping to avoid using *net.TCPConn.(ReadFrom)
// See also https://github.com/Dreamacro/clash/pull/1209
_, err := io.CopyBuffer(N.WriteOnlyWriter{Writer: leftConn}, N.ReadOnlyReader{Reader: rightConn}, buf)
pool.Put(buf)
leftConn.SetReadDeadline(time.Now())
ch <- err
}()
buf := pool.Get(pool.RelayBufferSize)
io.CopyBuffer(N.WriteOnlyWriter{Writer: rightConn}, N.ReadOnlyReader{Reader: leftConn}, buf)
pool.Put(buf)
rightConn.SetReadDeadline(time.Now())
<-ch
}
func tcpKeepAlive(c net.Conn) {
if tcp, ok := c.(*net.TCPConn); ok {
tcp.SetKeepAlive(true)
}
} }