Fix: tunnel manager & tracker race condition (#1048)

This commit is contained in:
Jason Lyu 2020-10-29 17:51:14 +08:00 committed by GitHub
parent b98e9ea202
commit 87e4d94290
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 82 additions and 68 deletions

View File

@ -6,11 +6,12 @@ import (
"errors" "errors"
"net" "net"
"net/http" "net/http"
"sync/atomic"
"time" "time"
"github.com/Dreamacro/clash/common/queue" "github.com/Dreamacro/clash/common/queue"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"go.uber.org/atomic"
) )
type Base struct { type Base struct {
@ -95,11 +96,11 @@ func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn {
type Proxy struct { type Proxy struct {
C.ProxyAdapter C.ProxyAdapter
history *queue.Queue history *queue.Queue
alive uint32 alive *atomic.Bool
} }
func (p *Proxy) Alive() bool { func (p *Proxy) Alive() bool {
return atomic.LoadUint32(&p.alive) > 0 return p.alive.Load()
} }
func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
@ -111,7 +112,7 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata) conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil { if err != nil {
atomic.StoreUint32(&p.alive, 0) p.alive.Store(false)
} }
return conn, err return conn, err
} }
@ -128,7 +129,7 @@ func (p *Proxy) DelayHistory() []C.DelayHistory {
// LastDelay return last history record. if proxy is not alive, return the max value of uint16. // LastDelay return last history record. if proxy is not alive, return the max value of uint16.
func (p *Proxy) LastDelay() (delay uint16) { func (p *Proxy) LastDelay() (delay uint16) {
var max uint16 = 0xffff var max uint16 = 0xffff
if atomic.LoadUint32(&p.alive) == 0 { if !p.alive.Load() {
return max return max
} }
@ -159,11 +160,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
// URLTest get the delay for the specified URL // URLTest get the delay for the specified URL
func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
defer func() { defer func() {
if err == nil { p.alive.Store(err == nil)
atomic.StoreUint32(&p.alive, 1)
} else {
atomic.StoreUint32(&p.alive, 0)
}
record := C.DelayHistory{Time: time.Now()} record := C.DelayHistory{Time: time.Now()}
if err == nil { if err == nil {
record.Delay = t record.Delay = t
@ -219,5 +216,5 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
} }
func NewProxy(adapter C.ProxyAdapter) *Proxy { func NewProxy(adapter C.ProxyAdapter) *Proxy {
return &Proxy{adapter, queue.New(10), 1} return &Proxy{adapter, queue.New(10), atomic.NewBool(true)}
} }

View File

@ -2,11 +2,11 @@ package observable
import ( import (
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/atomic"
) )
func iterator(item []interface{}) chan interface{} { func iterator(item []interface{}) chan interface{} {
@ -33,25 +33,25 @@ func TestObservable(t *testing.T) {
assert.Equal(t, count, 5) assert.Equal(t, count, 5)
} }
func TestObservable_MutilSubscribe(t *testing.T) { func TestObservable_MultiSubscribe(t *testing.T) {
iter := iterator([]interface{}{1, 2, 3, 4, 5}) iter := iterator([]interface{}{1, 2, 3, 4, 5})
src := NewObservable(iter) src := NewObservable(iter)
ch1, _ := src.Subscribe() ch1, _ := src.Subscribe()
ch2, _ := src.Subscribe() ch2, _ := src.Subscribe()
var count int32 var count = atomic.NewInt32(0)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
waitCh := func(ch <-chan interface{}) { waitCh := func(ch <-chan interface{}) {
for range ch { for range ch {
atomic.AddInt32(&count, 1) count.Inc()
} }
wg.Done() wg.Done()
} }
go waitCh(ch1) go waitCh(ch1)
go waitCh(ch2) go waitCh(ch2)
wg.Wait() wg.Wait()
assert.Equal(t, int32(10), count) assert.Equal(t, int32(10), count.Load())
} }
func TestObservable_UnSubscribe(t *testing.T) { func TestObservable_UnSubscribe(t *testing.T) {

View File

@ -2,17 +2,17 @@ package singledo
import ( import (
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/atomic"
) )
func TestBasic(t *testing.T) { func TestBasic(t *testing.T) {
single := NewSingle(time.Millisecond * 30) single := NewSingle(time.Millisecond * 30)
foo := 0 foo := 0
var shardCount int32 = 0 var shardCount = atomic.NewInt32(0)
call := func() (interface{}, error) { call := func() (interface{}, error) {
foo++ foo++
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 5)
@ -26,7 +26,7 @@ func TestBasic(t *testing.T) {
go func() { go func() {
_, _, shard := single.Do(call) _, _, shard := single.Do(call)
if shard { if shard {
atomic.AddInt32(&shardCount, 1) shardCount.Inc()
} }
wg.Done() wg.Done()
}() }()
@ -34,7 +34,7 @@ func TestBasic(t *testing.T) {
wg.Wait() wg.Wait()
assert.Equal(t, 1, foo) assert.Equal(t, 1, foo)
assert.Equal(t, int32(4), shardCount) assert.Equal(t, int32(4), shardCount.Load())
} }
func TestTimer(t *testing.T) { func TestTimer(t *testing.T) {

1
go.mod
View File

@ -13,6 +13,7 @@ require (
github.com/oschwald/geoip2-golang v1.4.0 github.com/oschwald/geoip2-golang v1.4.0
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.7.0
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1
go.uber.org/atomic v1.7.0
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
golang.org/x/net v0.0.0-20201020065357-d65d470038a5 golang.org/x/net v0.0.0-20201020065357-d65d470038a5
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520

3
go.sum
View File

@ -25,9 +25,12 @@ github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View File

@ -2,26 +2,34 @@ package tunnel
import ( import (
"sync" "sync"
"sync/atomic"
"time" "time"
"go.uber.org/atomic"
) )
var DefaultManager *Manager var DefaultManager *Manager
func init() { func init() {
DefaultManager = &Manager{} DefaultManager = &Manager{
uploadTemp: atomic.NewInt64(0),
downloadTemp: atomic.NewInt64(0),
uploadBlip: atomic.NewInt64(0),
downloadBlip: atomic.NewInt64(0),
uploadTotal: atomic.NewInt64(0),
downloadTotal: atomic.NewInt64(0),
}
go DefaultManager.handle() go DefaultManager.handle()
} }
type Manager struct { type Manager struct {
connections sync.Map connections sync.Map
uploadTemp int64 uploadTemp *atomic.Int64
downloadTemp int64 downloadTemp *atomic.Int64
uploadBlip int64 uploadBlip *atomic.Int64
downloadBlip int64 downloadBlip *atomic.Int64
uploadTotal int64 uploadTotal *atomic.Int64
downloadTotal int64 downloadTotal *atomic.Int64
} }
func (m *Manager) Join(c tracker) { func (m *Manager) Join(c tracker) {
@ -33,17 +41,17 @@ func (m *Manager) Leave(c tracker) {
} }
func (m *Manager) PushUploaded(size int64) { func (m *Manager) PushUploaded(size int64) {
atomic.AddInt64(&m.uploadTemp, size) m.uploadTemp.Add(size)
atomic.AddInt64(&m.uploadTotal, size) m.uploadTotal.Add(size)
} }
func (m *Manager) PushDownloaded(size int64) { func (m *Manager) PushDownloaded(size int64) {
atomic.AddInt64(&m.downloadTemp, size) m.downloadTemp.Add(size)
atomic.AddInt64(&m.downloadTotal, size) m.downloadTotal.Add(size)
} }
func (m *Manager) Now() (up int64, down int64) { func (m *Manager) Now() (up int64, down int64) {
return atomic.LoadInt64(&m.uploadBlip), atomic.LoadInt64(&m.downloadBlip) return m.uploadBlip.Load(), m.downloadBlip.Load()
} }
func (m *Manager) Snapshot() *Snapshot { func (m *Manager) Snapshot() *Snapshot {
@ -54,29 +62,29 @@ func (m *Manager) Snapshot() *Snapshot {
}) })
return &Snapshot{ return &Snapshot{
UploadTotal: atomic.LoadInt64(&m.uploadTotal), UploadTotal: m.uploadTotal.Load(),
DownloadTotal: atomic.LoadInt64(&m.downloadTotal), DownloadTotal: m.downloadTotal.Load(),
Connections: connections, Connections: connections,
} }
} }
func (m *Manager) ResetStatistic() { func (m *Manager) ResetStatistic() {
m.uploadTemp = 0 m.uploadTemp.Store(0)
m.uploadBlip = 0 m.uploadBlip.Store(0)
m.uploadTotal = 0 m.uploadTotal.Store(0)
m.downloadTemp = 0 m.downloadTemp.Store(0)
m.downloadBlip = 0 m.downloadBlip.Store(0)
m.downloadTotal = 0 m.downloadTotal.Store(0)
} }
func (m *Manager) handle() { func (m *Manager) handle() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
for range ticker.C { for range ticker.C {
atomic.StoreInt64(&m.uploadBlip, atomic.LoadInt64(&m.uploadTemp)) m.uploadBlip.Store(m.uploadTemp.Load())
atomic.StoreInt64(&m.uploadTemp, 0) m.uploadTemp.Store(0)
atomic.StoreInt64(&m.downloadBlip, atomic.LoadInt64(&m.downloadTemp)) m.downloadBlip.Store(m.downloadTemp.Load())
atomic.StoreInt64(&m.downloadTemp, 0) m.downloadTemp.Store(0)
} }
} }

View File

@ -2,11 +2,12 @@ package tunnel
import ( import (
"net" "net"
"sync/atomic"
"time" "time"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"go.uber.org/atomic"
) )
type tracker interface { type tracker interface {
@ -15,14 +16,14 @@ type tracker interface {
} }
type trackerInfo struct { type trackerInfo struct {
UUID uuid.UUID `json:"id"` UUID uuid.UUID `json:"id"`
Metadata *C.Metadata `json:"metadata"` Metadata *C.Metadata `json:"metadata"`
UploadTotal int64 `json:"upload"` UploadTotal *atomic.Int64 `json:"upload"`
DownloadTotal int64 `json:"download"` DownloadTotal *atomic.Int64 `json:"download"`
Start time.Time `json:"start"` Start time.Time `json:"start"`
Chain C.Chain `json:"chains"` Chain C.Chain `json:"chains"`
Rule string `json:"rule"` Rule string `json:"rule"`
RulePayload string `json:"rulePayload"` RulePayload string `json:"rulePayload"`
} }
type tcpTracker struct { type tcpTracker struct {
@ -39,7 +40,7 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
n, err := tt.Conn.Read(b) n, err := tt.Conn.Read(b)
download := int64(n) download := int64(n)
tt.manager.PushDownloaded(download) tt.manager.PushDownloaded(download)
atomic.AddInt64(&tt.DownloadTotal, download) tt.DownloadTotal.Add(download)
return n, err return n, err
} }
@ -47,7 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b) n, err := tt.Conn.Write(b)
upload := int64(n) upload := int64(n)
tt.manager.PushUploaded(upload) tt.manager.PushUploaded(upload)
atomic.AddInt64(&tt.UploadTotal, upload) tt.UploadTotal.Add(upload)
return n, err return n, err
} }
@ -63,11 +64,13 @@ func newTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
Conn: conn, Conn: conn,
manager: manager, manager: manager,
trackerInfo: &trackerInfo{ trackerInfo: &trackerInfo{
UUID: uuid, UUID: uuid,
Start: time.Now(), Start: time.Now(),
Metadata: metadata, Metadata: metadata,
Chain: conn.Chains(), Chain: conn.Chains(),
Rule: "", Rule: "",
UploadTotal: atomic.NewInt64(0),
DownloadTotal: atomic.NewInt64(0),
}, },
} }
@ -94,7 +97,7 @@ func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := ut.PacketConn.ReadFrom(b) n, addr, err := ut.PacketConn.ReadFrom(b)
download := int64(n) download := int64(n)
ut.manager.PushDownloaded(download) ut.manager.PushDownloaded(download)
atomic.AddInt64(&ut.DownloadTotal, download) ut.DownloadTotal.Add(download)
return n, addr, err return n, addr, err
} }
@ -102,7 +105,7 @@ func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) {
n, err := ut.PacketConn.WriteTo(b, addr) n, err := ut.PacketConn.WriteTo(b, addr)
upload := int64(n) upload := int64(n)
ut.manager.PushUploaded(upload) ut.manager.PushUploaded(upload)
atomic.AddInt64(&ut.UploadTotal, upload) ut.UploadTotal.Add(upload)
return n, err return n, err
} }
@ -118,11 +121,13 @@ func newUDPTracker(conn C.PacketConn, manager *Manager, metadata *C.Metadata, ru
PacketConn: conn, PacketConn: conn,
manager: manager, manager: manager,
trackerInfo: &trackerInfo{ trackerInfo: &trackerInfo{
UUID: uuid, UUID: uuid,
Start: time.Now(), Start: time.Now(),
Metadata: metadata, Metadata: metadata,
Chain: conn.Chains(), Chain: conn.Chains(),
Rule: "", Rule: "",
UploadTotal: atomic.NewInt64(0),
DownloadTotal: atomic.NewInt64(0),
}, },
} }