Improve(fakeip): use lru cache to avoid outdate
This commit is contained in:
@ -4,22 +4,72 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
)
|
||||
|
||||
// Pool is a implementation about fake ip generator without storage
|
||||
type Pool struct {
|
||||
max uint32
|
||||
min uint32
|
||||
offset uint32
|
||||
mux *sync.Mutex
|
||||
max uint32
|
||||
min uint32
|
||||
gateway uint32
|
||||
offset uint32
|
||||
mux *sync.Mutex
|
||||
cache *cache.LruCache
|
||||
}
|
||||
|
||||
// Get return a new fake ip
|
||||
func (p *Pool) Get() net.IP {
|
||||
// Lookup return a fake ip with host
|
||||
func (p *Pool) Lookup(host string) net.IP {
|
||||
p.mux.Lock()
|
||||
defer p.mux.Unlock()
|
||||
ip := uintToIP(p.min + p.offset)
|
||||
p.offset = (p.offset + 1) % (p.max - p.min)
|
||||
if ip, exist := p.cache.Get(host); exist {
|
||||
return ip.(net.IP)
|
||||
}
|
||||
|
||||
ip := p.get(host)
|
||||
p.cache.Set(host, ip)
|
||||
return ip
|
||||
}
|
||||
|
||||
// LookBack return host with the fake ip
|
||||
func (p *Pool) LookBack(ip net.IP) (string, bool) {
|
||||
p.mux.Lock()
|
||||
defer p.mux.Unlock()
|
||||
|
||||
if ip = ip.To4(); ip == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
n := ipToUint(ip.To4())
|
||||
offset := n - p.min + 1
|
||||
|
||||
if host, exist := p.cache.Get(offset); exist {
|
||||
return host.(string), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Gateway return gateway ip
|
||||
func (p *Pool) Gateway() net.IP {
|
||||
return uintToIP(p.gateway)
|
||||
}
|
||||
|
||||
func (p *Pool) get(host string) net.IP {
|
||||
current := p.offset
|
||||
for {
|
||||
p.offset = (p.offset + 1) % (p.max - p.min)
|
||||
// Avoid infinite loops
|
||||
if p.offset == current {
|
||||
break
|
||||
}
|
||||
|
||||
if _, exist := p.cache.Get(p.offset); !exist {
|
||||
break
|
||||
}
|
||||
}
|
||||
ip := uintToIP(p.min + p.offset - 1)
|
||||
p.cache.Set(p.offset, host)
|
||||
return ip
|
||||
}
|
||||
|
||||
@ -36,8 +86,8 @@ func uintToIP(v uint32) net.IP {
|
||||
}
|
||||
|
||||
// New return Pool instance
|
||||
func New(ipnet *net.IPNet) (*Pool, error) {
|
||||
min := ipToUint(ipnet.IP) + 1
|
||||
func New(ipnet *net.IPNet, size int) (*Pool, error) {
|
||||
min := ipToUint(ipnet.IP) + 2
|
||||
|
||||
ones, bits := ipnet.Mask.Size()
|
||||
total := 1<<uint(bits-ones) - 2
|
||||
@ -46,10 +96,12 @@ func New(ipnet *net.IPNet) (*Pool, error) {
|
||||
return nil, errors.New("ipnet don't have valid ip")
|
||||
}
|
||||
|
||||
max := min + uint32(total)
|
||||
max := min + uint32(total) - 1
|
||||
return &Pool{
|
||||
min: min,
|
||||
max: max,
|
||||
mux: &sync.Mutex{},
|
||||
min: min,
|
||||
max: max,
|
||||
gateway: min - 1,
|
||||
mux: &sync.Mutex{},
|
||||
cache: cache.NewLRUCache(cache.WithSize(size * 2)),
|
||||
}, nil
|
||||
}
|
||||
|
@ -3,42 +3,49 @@ package fakeip
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPool_Basic(t *testing.T) {
|
||||
_, ipnet, _ := net.ParseCIDR("192.168.0.1/30")
|
||||
pool, _ := New(ipnet)
|
||||
_, ipnet, _ := net.ParseCIDR("192.168.0.1/29")
|
||||
pool, _ := New(ipnet, 10)
|
||||
|
||||
first := pool.Get()
|
||||
last := pool.Get()
|
||||
first := pool.Lookup("foo.com")
|
||||
last := pool.Lookup("bar.com")
|
||||
bar, exist := pool.LookBack(last)
|
||||
|
||||
if !first.Equal(net.IP{192, 168, 0, 1}) {
|
||||
t.Error("should get right first ip, instead of", first.String())
|
||||
}
|
||||
|
||||
if !last.Equal(net.IP{192, 168, 0, 2}) {
|
||||
t.Error("should get right last ip, instead of", first.String())
|
||||
}
|
||||
assert.True(t, first.Equal(net.IP{192, 168, 0, 2}))
|
||||
assert.True(t, last.Equal(net.IP{192, 168, 0, 3}))
|
||||
assert.True(t, exist)
|
||||
assert.Equal(t, bar, "bar.com")
|
||||
}
|
||||
|
||||
func TestPool_Cycle(t *testing.T) {
|
||||
_, ipnet, _ := net.ParseCIDR("192.168.0.1/30")
|
||||
pool, _ := New(ipnet)
|
||||
pool, _ := New(ipnet, 10)
|
||||
|
||||
first := pool.Get()
|
||||
pool.Get()
|
||||
same := pool.Get()
|
||||
first := pool.Lookup("foo.com")
|
||||
same := pool.Lookup("baz.com")
|
||||
|
||||
if !first.Equal(same) {
|
||||
t.Error("should return same ip", first.String())
|
||||
}
|
||||
assert.True(t, first.Equal(same))
|
||||
}
|
||||
|
||||
func TestPool_MaxCacheSize(t *testing.T) {
|
||||
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24")
|
||||
pool, _ := New(ipnet, 2)
|
||||
|
||||
first := pool.Lookup("foo.com")
|
||||
pool.Lookup("bar.com")
|
||||
pool.Lookup("baz.com")
|
||||
next := pool.Lookup("foo.com")
|
||||
|
||||
assert.False(t, first.Equal(next))
|
||||
}
|
||||
|
||||
func TestPool_Error(t *testing.T) {
|
||||
_, ipnet, _ := net.ParseCIDR("192.168.0.1/31")
|
||||
_, err := New(ipnet)
|
||||
_, err := New(ipnet, 10)
|
||||
|
||||
if err == nil {
|
||||
t.Error("should return err")
|
||||
}
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user