Improve(fakeip): use lru cache to avoid outdate

This commit is contained in:
Dreamacro
2019-07-26 19:09:13 +08:00
parent 1702e7ddb4
commit 271ed2b9c1
9 changed files with 386 additions and 67 deletions

View File

@ -4,22 +4,72 @@ import (
"errors"
"net"
"sync"
"github.com/Dreamacro/clash/common/cache"
)
// Pool is a implementation about fake ip generator without storage
type Pool struct {
max uint32
min uint32
offset uint32
mux *sync.Mutex
max uint32
min uint32
gateway uint32
offset uint32
mux *sync.Mutex
cache *cache.LruCache
}
// Get return a new fake ip
func (p *Pool) Get() net.IP {
// Lookup return a fake ip with host
func (p *Pool) Lookup(host string) net.IP {
p.mux.Lock()
defer p.mux.Unlock()
ip := uintToIP(p.min + p.offset)
p.offset = (p.offset + 1) % (p.max - p.min)
if ip, exist := p.cache.Get(host); exist {
return ip.(net.IP)
}
ip := p.get(host)
p.cache.Set(host, ip)
return ip
}
// LookBack return host with the fake ip
func (p *Pool) LookBack(ip net.IP) (string, bool) {
p.mux.Lock()
defer p.mux.Unlock()
if ip = ip.To4(); ip == nil {
return "", false
}
n := ipToUint(ip.To4())
offset := n - p.min + 1
if host, exist := p.cache.Get(offset); exist {
return host.(string), true
}
return "", false
}
// Gateway return gateway ip
func (p *Pool) Gateway() net.IP {
return uintToIP(p.gateway)
}
func (p *Pool) get(host string) net.IP {
current := p.offset
for {
p.offset = (p.offset + 1) % (p.max - p.min)
// Avoid infinite loops
if p.offset == current {
break
}
if _, exist := p.cache.Get(p.offset); !exist {
break
}
}
ip := uintToIP(p.min + p.offset - 1)
p.cache.Set(p.offset, host)
return ip
}
@ -36,8 +86,8 @@ func uintToIP(v uint32) net.IP {
}
// New return Pool instance
func New(ipnet *net.IPNet) (*Pool, error) {
min := ipToUint(ipnet.IP) + 1
func New(ipnet *net.IPNet, size int) (*Pool, error) {
min := ipToUint(ipnet.IP) + 2
ones, bits := ipnet.Mask.Size()
total := 1<<uint(bits-ones) - 2
@ -46,10 +96,12 @@ func New(ipnet *net.IPNet) (*Pool, error) {
return nil, errors.New("ipnet don't have valid ip")
}
max := min + uint32(total)
max := min + uint32(total) - 1
return &Pool{
min: min,
max: max,
mux: &sync.Mutex{},
min: min,
max: max,
gateway: min - 1,
mux: &sync.Mutex{},
cache: cache.NewLRUCache(cache.WithSize(size * 2)),
}, nil
}

View File

@ -3,42 +3,49 @@ package fakeip
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPool_Basic(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/30")
pool, _ := New(ipnet)
_, ipnet, _ := net.ParseCIDR("192.168.0.1/29")
pool, _ := New(ipnet, 10)
first := pool.Get()
last := pool.Get()
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
if !first.Equal(net.IP{192, 168, 0, 1}) {
t.Error("should get right first ip, instead of", first.String())
}
if !last.Equal(net.IP{192, 168, 0, 2}) {
t.Error("should get right last ip, instead of", first.String())
}
assert.True(t, first.Equal(net.IP{192, 168, 0, 2}))
assert.True(t, last.Equal(net.IP{192, 168, 0, 3}))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
}
func TestPool_Cycle(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/30")
pool, _ := New(ipnet)
pool, _ := New(ipnet, 10)
first := pool.Get()
pool.Get()
same := pool.Get()
first := pool.Lookup("foo.com")
same := pool.Lookup("baz.com")
if !first.Equal(same) {
t.Error("should return same ip", first.String())
}
assert.True(t, first.Equal(same))
}
func TestPool_MaxCacheSize(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24")
pool, _ := New(ipnet, 2)
first := pool.Lookup("foo.com")
pool.Lookup("bar.com")
pool.Lookup("baz.com")
next := pool.Lookup("foo.com")
assert.False(t, first.Equal(next))
}
func TestPool_Error(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/31")
_, err := New(ipnet)
_, err := New(ipnet, 10)
if err == nil {
t.Error("should return err")
}
assert.Error(t, err)
}