Chore: merge branch 'with-tun' into plus-pro

This commit is contained in:
yaling888 2022-04-12 00:47:45 +08:00
commit 571c34f140
17 changed files with 323 additions and 261 deletions

View File

@ -1,7 +1,7 @@
package fakeip package fakeip
import ( import (
"net" "net/netip"
"github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/profile/cachefile"
) )
@ -11,22 +11,27 @@ type cachefileStore struct {
} }
// GetByHost implements store.GetByHost // GetByHost implements store.GetByHost
func (c *cachefileStore) GetByHost(host string) (net.IP, bool) { func (c *cachefileStore) GetByHost(host string) (netip.Addr, bool) {
elm := c.cache.GetFakeip([]byte(host)) elm := c.cache.GetFakeip([]byte(host))
if elm == nil { if elm == nil {
return nil, false return netip.Addr{}, false
}
if len(elm) == 4 {
return netip.AddrFrom4(*(*[4]byte)(elm)), true
} else {
return netip.AddrFrom16(*(*[16]byte)(elm)), true
} }
return net.IP(elm), true
} }
// PutByHost implements store.PutByHost // PutByHost implements store.PutByHost
func (c *cachefileStore) PutByHost(host string, ip net.IP) { func (c *cachefileStore) PutByHost(host string, ip netip.Addr) {
c.cache.PutFakeip([]byte(host), ip) c.cache.PutFakeip([]byte(host), ip.AsSlice())
} }
// GetByIP implements store.GetByIP // GetByIP implements store.GetByIP
func (c *cachefileStore) GetByIP(ip net.IP) (string, bool) { func (c *cachefileStore) GetByIP(ip netip.Addr) (string, bool) {
elm := c.cache.GetFakeip(ip.To4()) elm := c.cache.GetFakeip(ip.AsSlice())
if elm == nil { if elm == nil {
return "", false return "", false
} }
@ -34,18 +39,18 @@ func (c *cachefileStore) GetByIP(ip net.IP) (string, bool) {
} }
// PutByIP implements store.PutByIP // PutByIP implements store.PutByIP
func (c *cachefileStore) PutByIP(ip net.IP, host string) { func (c *cachefileStore) PutByIP(ip netip.Addr, host string) {
c.cache.PutFakeip(ip.To4(), []byte(host)) c.cache.PutFakeip(ip.AsSlice(), []byte(host))
} }
// DelByIP implements store.DelByIP // DelByIP implements store.DelByIP
func (c *cachefileStore) DelByIP(ip net.IP) { func (c *cachefileStore) DelByIP(ip netip.Addr) {
ip = ip.To4() addr := ip.AsSlice()
c.cache.DelFakeipPair(ip, c.cache.GetFakeip(ip.To4())) c.cache.DelFakeipPair(addr, c.cache.GetFakeip(addr))
} }
// Exist implements store.Exist // Exist implements store.Exist
func (c *cachefileStore) Exist(ip net.IP) bool { func (c *cachefileStore) Exist(ip netip.Addr) bool {
_, exist := c.GetByIP(ip) _, exist := c.GetByIP(ip)
return exist return exist
} }

View File

@ -1,35 +1,35 @@
package fakeip package fakeip
import ( import (
"net" "net/netip"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
) )
type memoryStore struct { type memoryStore struct {
cacheIP *cache.LruCache[string, net.IP] cacheIP *cache.LruCache[string, netip.Addr]
cacheHost *cache.LruCache[uint32, string] cacheHost *cache.LruCache[netip.Addr, string]
} }
// GetByHost implements store.GetByHost // GetByHost implements store.GetByHost
func (m *memoryStore) GetByHost(host string) (net.IP, bool) { func (m *memoryStore) GetByHost(host string) (netip.Addr, bool) {
if ip, exist := m.cacheIP.Get(host); exist { if ip, exist := m.cacheIP.Get(host); exist {
// ensure ip --> host on head of linked list // ensure ip --> host on head of linked list
m.cacheHost.Get(ipToUint(ip.To4())) m.cacheHost.Get(ip)
return ip, true return ip, true
} }
return nil, false return netip.Addr{}, false
} }
// PutByHost implements store.PutByHost // PutByHost implements store.PutByHost
func (m *memoryStore) PutByHost(host string, ip net.IP) { func (m *memoryStore) PutByHost(host string, ip netip.Addr) {
m.cacheIP.Set(host, ip) m.cacheIP.Set(host, ip)
} }
// GetByIP implements store.GetByIP // GetByIP implements store.GetByIP
func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { func (m *memoryStore) GetByIP(ip netip.Addr) (string, bool) {
if host, exist := m.cacheHost.Get(ipToUint(ip.To4())); exist { if host, exist := m.cacheHost.Get(ip); exist {
// ensure host --> ip on head of linked list // ensure host --> ip on head of linked list
m.cacheIP.Get(host) m.cacheIP.Get(host)
return host, true return host, true
@ -39,22 +39,21 @@ func (m *memoryStore) GetByIP(ip net.IP) (string, bool) {
} }
// PutByIP implements store.PutByIP // PutByIP implements store.PutByIP
func (m *memoryStore) PutByIP(ip net.IP, host string) { func (m *memoryStore) PutByIP(ip netip.Addr, host string) {
m.cacheHost.Set(ipToUint(ip.To4()), host) m.cacheHost.Set(ip, host)
} }
// DelByIP implements store.DelByIP // DelByIP implements store.DelByIP
func (m *memoryStore) DelByIP(ip net.IP) { func (m *memoryStore) DelByIP(ip netip.Addr) {
ipNum := ipToUint(ip.To4()) if host, exist := m.cacheHost.Get(ip); exist {
if host, exist := m.cacheHost.Get(ipNum); exist {
m.cacheIP.Delete(host) m.cacheIP.Delete(host)
} }
m.cacheHost.Delete(ipNum) m.cacheHost.Delete(ip)
} }
// Exist implements store.Exist // Exist implements store.Exist
func (m *memoryStore) Exist(ip net.IP) bool { func (m *memoryStore) Exist(ip netip.Addr) bool {
return m.cacheHost.Exist(ipToUint(ip.To4())) return m.cacheHost.Exist(ip)
} }
// CloneTo implements store.CloneTo // CloneTo implements store.CloneTo
@ -74,7 +73,7 @@ func (m *memoryStore) FlushFakeIP() error {
func newMemoryStore(size int) *memoryStore { func newMemoryStore(size int) *memoryStore {
return &memoryStore{ return &memoryStore{
cacheIP: cache.NewLRUCache[string, net.IP](cache.WithSize[string, net.IP](size)), cacheIP: cache.NewLRUCache[string, netip.Addr](cache.WithSize[string, netip.Addr](size)),
cacheHost: cache.NewLRUCache[uint32, string](cache.WithSize[uint32, string](size)), cacheHost: cache.NewLRUCache[netip.Addr, string](cache.WithSize[netip.Addr, string](size)),
} }
} }

View File

@ -2,39 +2,52 @@ package fakeip
import ( import (
"errors" "errors"
"net" "math/bits"
"net/netip"
"sync" "sync"
_ "unsafe"
"github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/profile/cachefile"
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
) )
//go:linkname beUint64 net/netip.beUint64
func beUint64(b []byte) uint64
//go:linkname bePutUint64 net/netip.bePutUint64
func bePutUint64(b []byte, v uint64)
type uint128 struct {
hi uint64
lo uint64
}
type store interface { type store interface {
GetByHost(host string) (net.IP, bool) GetByHost(host string) (netip.Addr, bool)
PutByHost(host string, ip net.IP) PutByHost(host string, ip netip.Addr)
GetByIP(ip net.IP) (string, bool) GetByIP(ip netip.Addr) (string, bool)
PutByIP(ip net.IP, host string) PutByIP(ip netip.Addr, host string)
DelByIP(ip net.IP) DelByIP(ip netip.Addr)
Exist(ip net.IP) bool Exist(ip netip.Addr) bool
CloneTo(store) CloneTo(store)
FlushFakeIP() error FlushFakeIP() error
} }
// Pool is a implementation about fake ip generator without storage // Pool is a implementation about fake ip generator without storage
type Pool struct { type Pool struct {
max uint32 gateway netip.Addr
min uint32 first netip.Addr
gateway uint32 last netip.Addr
broadcast uint32 offset netip.Addr
offset uint32 cycle bool
mux sync.Mutex mux sync.Mutex
host *trie.DomainTrie[bool] host *trie.DomainTrie[bool]
ipnet *net.IPNet ipnet *netip.Prefix
store store store store
} }
// Lookup return a fake ip with host // Lookup return a fake ip with host
func (p *Pool) Lookup(host string) net.IP { func (p *Pool) Lookup(host string) netip.Addr {
p.mux.Lock() p.mux.Lock()
defer p.mux.Unlock() defer p.mux.Unlock()
if ip, exist := p.store.GetByHost(host); exist { if ip, exist := p.store.GetByHost(host); exist {
@ -47,14 +60,10 @@ func (p *Pool) Lookup(host string) net.IP {
} }
// LookBack return host with the fake ip // LookBack return host with the fake ip
func (p *Pool) LookBack(ip net.IP) (string, bool) { func (p *Pool) LookBack(ip netip.Addr) (string, bool) {
p.mux.Lock() p.mux.Lock()
defer p.mux.Unlock() defer p.mux.Unlock()
if ip = ip.To4(); ip == nil {
return "", false
}
return p.store.GetByIP(ip) return p.store.GetByIP(ip)
} }
@ -67,29 +76,25 @@ func (p *Pool) ShouldSkipped(domain string) bool {
} }
// Exist returns if given ip exists in fake-ip pool // Exist returns if given ip exists in fake-ip pool
func (p *Pool) Exist(ip net.IP) bool { func (p *Pool) Exist(ip netip.Addr) bool {
p.mux.Lock() p.mux.Lock()
defer p.mux.Unlock() defer p.mux.Unlock()
if ip = ip.To4(); ip == nil {
return false
}
return p.store.Exist(ip) return p.store.Exist(ip)
} }
// Gateway return gateway ip // Gateway return gateway ip
func (p *Pool) Gateway() net.IP { func (p *Pool) Gateway() netip.Addr {
return uintToIP(p.gateway) return p.gateway
} }
// Broadcast return broadcast ip // Broadcast return the last ip
func (p *Pool) Broadcast() net.IP { func (p *Pool) Broadcast() netip.Addr {
return uintToIP(p.broadcast) return p.last
} }
// IPNet return raw ipnet // IPNet return raw ipnet
func (p *Pool) IPNet() *net.IPNet { func (p *Pool) IPNet() *netip.Prefix {
return p.ipnet return p.ipnet
} }
@ -98,46 +103,28 @@ func (p *Pool) CloneFrom(o *Pool) {
o.store.CloneTo(p.store) o.store.CloneTo(p.store)
} }
func (p *Pool) get(host string) net.IP { func (p *Pool) get(host string) netip.Addr {
current := p.offset p.offset = p.offset.Next()
for {
p.offset = (p.offset + 1) % (p.max - p.min) if !p.offset.Less(p.last) {
// Avoid infinite loops p.cycle = true
if p.offset == current { p.offset = p.first
p.offset = (p.offset + 1) % (p.max - p.min)
ip := uintToIP(p.min + p.offset - 1)
p.store.DelByIP(ip)
break
} }
ip := uintToIP(p.min + p.offset - 1) if p.cycle {
if !p.store.Exist(ip) { p.store.DelByIP(p.offset)
break
} }
}
ip := uintToIP(p.min + p.offset - 1) p.store.PutByIP(p.offset, host)
p.store.PutByIP(ip, host) return p.offset
return ip
} }
func (p *Pool) FlushFakeIP() error { func (p *Pool) FlushFakeIP() error {
return p.store.FlushFakeIP() return p.store.FlushFakeIP()
} }
func ipToUint(ip net.IP) uint32 {
v := uint32(ip[0]) << 24
v += uint32(ip[1]) << 16
v += uint32(ip[2]) << 8
v += uint32(ip[3])
return v
}
func uintToIP(v uint32) net.IP {
return net.IP{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
type Options struct { type Options struct {
IPNet *net.IPNet IPNet *netip.Prefix
Host *trie.DomainTrie[bool] Host *trie.DomainTrie[bool]
// Size sets the maximum number of entries in memory // Size sets the maximum number of entries in memory
@ -151,21 +138,23 @@ type Options struct {
// New return Pool instance // New return Pool instance
func New(options Options) (*Pool, error) { func New(options Options) (*Pool, error) {
min := ipToUint(options.IPNet.IP) + 3 var (
hostAddr = options.IPNet.Masked().Addr()
gateway = hostAddr.Next()
first = gateway.Next().Next()
last = add(hostAddr, 1<<uint64(hostAddr.BitLen()-options.IPNet.Bits())-1)
)
ones, bits := options.IPNet.Mask.Size() if !options.IPNet.IsValid() || !first.Less(last) || !options.IPNet.Contains(last) {
total := 1<<uint(bits-ones) - 4
if total <= 0 {
return nil, errors.New("ipnet don't have valid ip") return nil, errors.New("ipnet don't have valid ip")
} }
max := min + uint32(total) - 1
pool := &Pool{ pool := &Pool{
min: min, gateway: gateway,
max: max, first: first,
gateway: min - 2, last: last,
broadcast: max + 1, offset: first.Prev(),
cycle: false,
host: options.Host, host: options.Host,
ipnet: options.IPNet, ipnet: options.IPNet,
} }
@ -179,3 +168,29 @@ func New(options Options) (*Pool, error) {
return pool, nil return pool, nil
} }
// add returns addr + n.
func add(addr netip.Addr, n uint64) netip.Addr {
buf := addr.As16()
u := uint128{
beUint64(buf[:8]),
beUint64(buf[8:]),
}
lo, carry := bits.Add64(u.lo, n, 0)
u.hi = u.hi + carry
u.lo = lo
bePutUint64(buf[:8], u.hi)
bePutUint64(buf[8:], u.lo)
a := netip.AddrFrom16(buf)
if addr.Is4() {
return a.Unmap()
}
return a
}

View File

@ -2,7 +2,7 @@ package fakeip
import ( import (
"fmt" "fmt"
"net" "net/netip"
"os" "os"
"testing" "testing"
"time" "time"
@ -49,9 +49,9 @@ func createCachefileStore(options Options) (*Pool, string, error) {
} }
func TestPool_Basic(t *testing.T) { func TestPool_Basic(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.0/28") ipnet := netip.MustParsePrefix("192.168.0.0/28")
pools, tempfile, err := createPools(Options{ pools, tempfile, err := createPools(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 10, Size: 10,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -62,24 +62,52 @@ func TestPool_Basic(t *testing.T) {
last := pool.Lookup("bar.com") last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last) bar, exist := pool.LookBack(last)
assert.True(t, first.Equal(net.IP{192, 168, 0, 3})) assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.Equal(t, pool.Lookup("foo.com"), net.IP{192, 168, 0, 3}) assert.True(t, pool.Lookup("foo.com") == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, last.Equal(net.IP{192, 168, 0, 4})) assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, exist) assert.True(t, exist)
assert.Equal(t, bar, "bar.com") assert.Equal(t, bar, "bar.com")
assert.Equal(t, pool.Gateway(), net.IP{192, 168, 0, 1}) assert.True(t, pool.Gateway() == netip.AddrFrom4([4]byte{192, 168, 0, 1}))
assert.Equal(t, pool.Broadcast(), net.IP{192, 168, 0, 15}) assert.True(t, pool.Broadcast() == netip.AddrFrom4([4]byte{192, 168, 0, 15}))
assert.Equal(t, pool.IPNet().String(), ipnet.String()) assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(net.IP{192, 168, 0, 4})) assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 4})))
assert.False(t, pool.Exist(net.IP{192, 168, 0, 5})) assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5})))
assert.False(t, pool.Exist(net.ParseIP("::1"))) assert.False(t, pool.Exist(netip.MustParseAddr("::1")))
}
}
func TestPool_BasicV6(t *testing.T) {
ipnet := netip.MustParsePrefix("2001:4860:4860::8888/118")
pools, tempfile, err := createPools(Options{
IPNet: &ipnet,
Size: 10,
})
assert.Nil(t, err)
defer os.Remove(tempfile)
for _, pool := range pools {
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8803"))
assert.True(t, pool.Lookup("foo.com") == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8803"))
assert.True(t, last == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.True(t, pool.Gateway() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801"))
assert.True(t, pool.Broadcast() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff"))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804")))
assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")))
assert.False(t, pool.Exist(netip.MustParseAddr("127.0.0.1")))
} }
} }
func TestPool_CycleUsed(t *testing.T) { func TestPool_CycleUsed(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.16/28") ipnet := netip.MustParsePrefix("192.168.0.16/28")
pools, tempfile, err := createPools(Options{ pools, tempfile, err := createPools(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 10, Size: 10,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -88,22 +116,22 @@ func TestPool_CycleUsed(t *testing.T) {
for _, pool := range pools { for _, pool := range pools {
foo := pool.Lookup("foo.com") foo := pool.Lookup("foo.com")
bar := pool.Lookup("bar.com") bar := pool.Lookup("bar.com")
for i := 0; i < 9; i++ { for i := 0; i < 10; i++ {
pool.Lookup(fmt.Sprintf("%d.com", i)) pool.Lookup(fmt.Sprintf("%d.com", i))
} }
baz := pool.Lookup("baz.com") baz := pool.Lookup("baz.com")
next := pool.Lookup("foo.com") next := pool.Lookup("foo.com")
assert.True(t, foo.Equal(baz)) assert.True(t, foo == baz)
assert.True(t, next.Equal(bar)) assert.True(t, next == bar)
} }
} }
func TestPool_Skip(t *testing.T) { func TestPool_Skip(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/29") ipnet := netip.MustParsePrefix("192.168.0.1/29")
tree := trie.New[bool]() tree := trie.New[bool]()
tree.Insert("example.com", true) tree.Insert("example.com", true)
pools, tempfile, err := createPools(Options{ pools, tempfile, err := createPools(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 10, Size: 10,
Host: tree, Host: tree,
}) })
@ -117,9 +145,9 @@ func TestPool_Skip(t *testing.T) {
} }
func TestPool_MaxCacheSize(t *testing.T) { func TestPool_MaxCacheSize(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24") ipnet := netip.MustParsePrefix("192.168.0.1/24")
pool, _ := New(Options{ pool, _ := New(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 2, Size: 2,
}) })
@ -128,13 +156,13 @@ func TestPool_MaxCacheSize(t *testing.T) {
pool.Lookup("baz.com") pool.Lookup("baz.com")
next := pool.Lookup("foo.com") next := pool.Lookup("foo.com")
assert.False(t, first.Equal(next)) assert.False(t, first == next)
} }
func TestPool_DoubleMapping(t *testing.T) { func TestPool_DoubleMapping(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24") ipnet := netip.MustParsePrefix("192.168.0.1/24")
pool, _ := New(Options{ pool, _ := New(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 2, Size: 2,
}) })
@ -158,23 +186,23 @@ func TestPool_DoubleMapping(t *testing.T) {
assert.False(t, bazExist) assert.False(t, bazExist)
assert.True(t, barExist) assert.True(t, barExist)
assert.False(t, bazIP.Equal(newBazIP)) assert.False(t, bazIP == newBazIP)
} }
func TestPool_Clone(t *testing.T) { func TestPool_Clone(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24") ipnet := netip.MustParsePrefix("192.168.0.1/24")
pool, _ := New(Options{ pool, _ := New(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 2, Size: 2,
}) })
first := pool.Lookup("foo.com") first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com") last := pool.Lookup("bar.com")
assert.True(t, first.Equal(net.IP{192, 168, 0, 3})) assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, last.Equal(net.IP{192, 168, 0, 4})) assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
newPool, _ := New(Options{ newPool, _ := New(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 2, Size: 2,
}) })
newPool.CloneFrom(pool) newPool.CloneFrom(pool)
@ -185,9 +213,9 @@ func TestPool_Clone(t *testing.T) {
} }
func TestPool_Error(t *testing.T) { func TestPool_Error(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/31") ipnet := netip.MustParsePrefix("192.168.0.1/31")
_, err := New(Options{ _, err := New(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 10, Size: 10,
}) })
@ -195,9 +223,9 @@ func TestPool_Error(t *testing.T) {
} }
func TestPool_FlushFileCache(t *testing.T) { func TestPool_FlushFileCache(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/28") ipnet := netip.MustParsePrefix("192.168.0.1/28")
pools, tempfile, err := createPools(Options{ pools, tempfile, err := createPools(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 10, Size: 10,
}) })
assert.Nil(t, err) assert.Nil(t, err)
@ -216,18 +244,18 @@ func TestPool_FlushFileCache(t *testing.T) {
next := pool.Lookup("baz.com") next := pool.Lookup("baz.com")
nero := pool.Lookup("foo.com") nero := pool.Lookup("foo.com")
assert.Equal(t, foo, fox) assert.True(t, foo == fox)
assert.NotEqual(t, foo, baz) assert.False(t, foo == baz)
assert.Equal(t, bar, bax) assert.True(t, bar == bax)
assert.NotEqual(t, bar, next) assert.False(t, bar == next)
assert.Equal(t, baz, nero) assert.True(t, baz == nero)
} }
} }
func TestPool_FlushMemoryCache(t *testing.T) { func TestPool_FlushMemoryCache(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/28") ipnet := netip.MustParsePrefix("192.168.0.1/28")
pool, _ := New(Options{ pool, _ := New(Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 10, Size: 10,
}) })
@ -243,9 +271,9 @@ func TestPool_FlushMemoryCache(t *testing.T) {
next := pool.Lookup("baz.com") next := pool.Lookup("baz.com")
nero := pool.Lookup("foo.com") nero := pool.Lookup("foo.com")
assert.Equal(t, foo, fox) assert.True(t, foo == fox)
assert.NotEqual(t, foo, baz) assert.False(t, foo == baz)
assert.Equal(t, bar, bax) assert.True(t, bar == bax)
assert.NotEqual(t, bar, next) assert.False(t, bar == next)
assert.Equal(t, baz, nero) assert.True(t, baz == nero)
} }

View File

@ -592,7 +592,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) {
} }
// add mitm.clash hosts // add mitm.clash hosts
if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{8, 8, 9, 9})); err != nil { if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{1, 2, 3, 4})); err != nil {
log.Errorln("insert mitm.clash to host error: %s", err.Error()) log.Errorln("insert mitm.clash to host error: %s", err.Error())
} }
@ -777,7 +777,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
} }
if cfg.EnhancedMode == C.DNSFakeIP { if cfg.EnhancedMode == C.DNSFakeIP {
_, ipnet, err := net.ParseCIDR(cfg.FakeIPRange) ipnet, err := netip.ParsePrefix(cfg.FakeIPRange)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -804,7 +804,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
} }
pool, err := fakeip.New(fakeip.Options{ pool, err := fakeip.New(fakeip.Options{
IPNet: ipnet, IPNet: &ipnet,
Size: 1000, Size: 1000,
Host: host, Host: host,
Persistence: rawCfg.Profile.StoreFakeIP, Persistence: rawCfg.Profile.StoreFakeIP,

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"net" "net"
"net/netip"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/fakeip"
@ -11,7 +12,7 @@ import (
type ResolverEnhancer struct { type ResolverEnhancer struct {
mode C.DNSMode mode C.DNSMode
fakePool *fakeip.Pool fakePool *fakeip.Pool
mapping *cache.LruCache[string, string] mapping *cache.LruCache[netip.Addr, string]
} }
func (h *ResolverEnhancer) FakeIPEnabled() bool { func (h *ResolverEnhancer) FakeIPEnabled() bool {
@ -28,7 +29,7 @@ func (h *ResolverEnhancer) IsExistFakeIP(ip net.IP) bool {
} }
if pool := h.fakePool; pool != nil { if pool := h.fakePool; pool != nil {
return pool.Exist(ip) return pool.Exist(ipToAddr(ip))
} }
return false return false
@ -39,8 +40,10 @@ func (h *ResolverEnhancer) IsFakeIP(ip net.IP) bool {
return false return false
} }
addr := ipToAddr(ip)
if pool := h.fakePool; pool != nil { if pool := h.fakePool; pool != nil {
return pool.IPNet().Contains(ip) && !pool.Gateway().Equal(ip) && !pool.Broadcast().Equal(ip) return pool.IPNet().Contains(addr) && addr != pool.Gateway() && addr != pool.Broadcast()
} }
return false return false
@ -52,21 +55,22 @@ func (h *ResolverEnhancer) IsFakeBroadcastIP(ip net.IP) bool {
} }
if pool := h.fakePool; pool != nil { if pool := h.fakePool; pool != nil {
return pool.Broadcast().Equal(ip) return pool.Broadcast() == ipToAddr(ip)
} }
return false return false
} }
func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) {
addr := ipToAddr(ip)
if pool := h.fakePool; pool != nil { if pool := h.fakePool; pool != nil {
if host, existed := pool.LookBack(ip); existed { if host, existed := pool.LookBack(addr); existed {
return host, true return host, true
} }
} }
if mapping := h.mapping; mapping != nil { if mapping := h.mapping; mapping != nil {
if host, existed := h.mapping.Get(ip.String()); existed { if host, existed := h.mapping.Get(addr); existed {
return host, true return host, true
} }
} }
@ -76,7 +80,7 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) {
func (h *ResolverEnhancer) InsertHostByIP(ip net.IP, host string) { func (h *ResolverEnhancer) InsertHostByIP(ip net.IP, host string) {
if mapping := h.mapping; mapping != nil { if mapping := h.mapping; mapping != nil {
h.mapping.Set(ip.String(), host) h.mapping.Set(ipToAddr(ip), host)
} }
} }
@ -99,11 +103,11 @@ func (h *ResolverEnhancer) FlushFakeIP() error {
func NewEnhancer(cfg Config) *ResolverEnhancer { func NewEnhancer(cfg Config) *ResolverEnhancer {
var fakePool *fakeip.Pool var fakePool *fakeip.Pool
var mapping *cache.LruCache[string, string] var mapping *cache.LruCache[netip.Addr, string]
if cfg.EnhancedMode != C.DNSNormal { if cfg.EnhancedMode != C.DNSNormal {
fakePool = cfg.Pool fakePool = cfg.Pool
mapping = cache.NewLRUCache[string, string](cache.WithSize[string, string](4096), cache.WithStale[string, string](true)) mapping = cache.NewLRUCache[netip.Addr, string](cache.WithSize[netip.Addr, string](4096), cache.WithStale[netip.Addr, string](true))
} }
return &ResolverEnhancer{ return &ResolverEnhancer{

View File

@ -21,7 +21,7 @@ type (
middleware func(next handler) handler middleware func(next handler) handler
) )
func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[string, string]) middleware { func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
@ -30,28 +30,25 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin
return next(ctx, r) return next(ctx, r)
} }
qName := strings.TrimRight(q.Name, ".") host := strings.TrimRight(q.Name, ".")
record := hosts.Search(qName)
record := hosts.Search(host)
if record == nil { if record == nil {
return next(ctx, r) return next(ctx, r)
} }
ip := record.Data ip := record.Data
if mapping != nil {
mapping.SetWithExpire(ip.Unmap().String(), qName, time.Now().Add(time.Second*5))
}
msg := r.Copy() msg := r.Copy()
if ip.Is4() && q.Qtype == D.TypeA { if ip.Is4() && q.Qtype == D.TypeA {
rr := &D.A{} rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 1} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
rr.A = ip.AsSlice() rr.A = ip.AsSlice()
msg.Answer = []D.RR{rr} msg.Answer = []D.RR{rr}
} else if ip.Is6() && q.Qtype == D.TypeAAAA { } else if ip.Is6() && q.Qtype == D.TypeAAAA {
rr := &D.AAAA{} rr := &D.AAAA{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 1} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
rr.AAAA = ip.AsSlice() rr.AAAA = ip.AsSlice()
msg.Answer = []D.RR{rr} msg.Answer = []D.RR{rr}
@ -59,6 +56,10 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin
return next(ctx, r) return next(ctx, r)
} }
if mapping != nil {
mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*10))
}
ctx.SetType(context.DNSTypeHost) ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess) msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true msg.Authoritative = true
@ -69,7 +70,7 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin
} }
} }
func withMapping(mapping *cache.LruCache[string, string]) middleware { func withMapping(mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler { return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0] q := r.Question[0]
@ -100,7 +101,7 @@ func withMapping(mapping *cache.LruCache[string, string]) middleware {
continue continue
} }
mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl))) mapping.SetWithExpire(ipToAddr(ip), host, time.Now().Add(time.Second*time.Duration(ttl)))
} }
return msg, nil return msg, nil
@ -130,7 +131,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
rr := &D.A{} rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := fakePool.Lookup(host) ip := fakePool.Lookup(host)
rr.A = ip rr.A = ip.AsSlice()
msg := r.Copy() msg := r.Copy()
msg.Answer = []D.RR{rr} msg.Answer = []D.RR{rr}

View File

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
@ -111,6 +112,22 @@ func msgToIP(msg *D.Msg) []net.IP {
return ips return ips
} }
func ipToAddr(ip net.IP) netip.Addr {
if ip == nil {
return netip.Addr{}
}
l := len(ip)
if l == 4 {
return netip.AddrFrom4(*(*[4]byte)(ip))
} else if l == 16 {
return netip.AddrFrom16(*(*[16]byte)(ip))
} else {
return netip.Addr{}
}
}
type wrapPacketConn struct { type wrapPacketConn struct {
net.PacketConn net.PacketConn
rAddr net.Addr rAddr net.Addr

View File

@ -195,9 +195,9 @@ func updateRuleProviders(providers map[string]C.Rule) {
} }
func updateTun(tun *config.Tun, dns *config.DNS) { func updateTun(tun *config.Tun, dns *config.DNS) {
var tunAddressPrefix string var tunAddressPrefix *netip.Prefix
if dns.FakeIPRange != nil { if dns.FakeIPRange != nil {
tunAddressPrefix = dns.FakeIPRange.IPNet().String() tunAddressPrefix = dns.FakeIPRange.IPNet()
} }
P.ReCreateTun(tun, tunAddressPrefix, tunnel.TCPIn(), tunnel.UDPIn()) P.ReCreateTun(tun, tunAddressPrefix, tunnel.TCPIn(), tunnel.UDPIn())

View File

@ -70,11 +70,11 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
RemoveExtraHTTPHostPort(request) RemoveExtraHTTPHostPort(request)
if request.URL.Scheme == "" || request.URL.Host == "" { if request.URL.Scheme == "" || request.URL.Host == "" {
resp = ResponseWith(request, http.StatusBadRequest) resp = responseWith(request, http.StatusBadRequest)
} else { } else {
resp, err = client.Do(request) resp, err = client.Do(request)
if err != nil { if err != nil {
resp = ResponseWith(request, http.StatusBadGateway) resp = responseWith(request, http.StatusBadGateway)
} }
} }
@ -103,7 +103,7 @@ func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http
if authenticator != nil { if authenticator != nil {
credential := parseBasicProxyAuthorization(request) credential := parseBasicProxyAuthorization(request)
if credential == "" { if credential == "" {
resp := ResponseWith(request, http.StatusProxyAuthRequired) resp := responseWith(request, http.StatusProxyAuthRequired)
resp.Header.Set("Proxy-Authenticate", "Basic") resp.Header.Set("Proxy-Authenticate", "Basic")
return resp return resp
} }
@ -117,14 +117,14 @@ func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http
if !authed { if !authed {
log.Infoln("Auth failed from %s", request.RemoteAddr) log.Infoln("Auth failed from %s", request.RemoteAddr)
return ResponseWith(request, http.StatusForbidden) return responseWith(request, http.StatusForbidden)
} }
} }
return nil return nil
} }
func ResponseWith(request *http.Request, statusCode int) *http.Response { func responseWith(request *http.Request, statusCode int) *http.Response {
return &http.Response{ return &http.Response{
StatusCode: statusCode, StatusCode: statusCode,
Status: http.StatusText(statusCode), Status: http.StatusText(statusCode),

View File

@ -40,7 +40,7 @@ func RemoveExtraHTTPHostPort(req *http.Request) {
host = req.URL.Host host = req.URL.Host
} }
if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" { if pHost, port, err := net.SplitHostPort(host); err == nil && (port == "80" || port == "443") {
host = pHost host = pHost
} }

View File

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"strconv" "strconv"
"sync" "sync"
@ -319,7 +320,7 @@ func ReCreateMixed(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.P
log.Infoln("Mixed(http+socks) proxy listening at: %s", mixedListener.Address()) log.Infoln("Mixed(http+socks) proxy listening at: %s", mixedListener.Address())
} }
func ReCreateTun(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { func ReCreateTun(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) {
tunMux.Lock() tunMux.Lock()
defer tunMux.Unlock() defer tunMux.Unlock()

View File

@ -17,7 +17,7 @@ import (
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
httpL "github.com/Dreamacro/clash/listener/http" H "github.com/Dreamacro/clash/listener/http"
) )
func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
@ -48,9 +48,12 @@ startOver:
readLoop: readLoop:
for { for {
_ = conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive err := conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive
if err != nil {
break readLoop
}
request, err := httpL.ReadRequest(conn.Reader()) request, err := H.ReadRequest(conn.Reader())
if err != nil { if err != nil {
handleError(opt, nil, err) handleError(opt, nil, err)
break readLoop break readLoop
@ -58,15 +61,15 @@ readLoop:
var response *http.Response var response *http.Response
session := NewSession(conn, request, response) session := newSession(conn, request, response)
source = parseSourceAddress(session.request, c, source) source = parseSourceAddress(session.request, c, source)
request.RemoteAddr = source.String() session.request.RemoteAddr = source.String()
if !trusted { if !trusted {
response = httpL.Authenticate(request, cache) session.response = H.Authenticate(session.request, cache)
trusted = response == nil trusted = session.response == nil
} }
if trusted { if trusted {
@ -84,19 +87,18 @@ readLoop:
break readLoop // close connection break readLoop // close connection
} }
buf := make([]byte, session.conn.(*N.BufferedConn).Buffered()) buff := make([]byte, session.conn.(*N.BufferedConn).Buffered())
_, _ = session.conn.Read(buf) _, _ = session.conn.Read(buff)
mc := &MultiReaderConn{ mrc := &multiReaderConn{
Conn: session.conn, Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), session.conn), reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
} }
// 22 is the TLS handshake. // TLS handshake.
// https://tools.ietf.org/html/rfc5246#section-6.2.1 if b[0] == 0x16 {
if b[0] == 22 {
// TODO serve by generic host name maybe better? // TODO serve by generic host name maybe better?
tlsConn := tls.Server(mc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) tlsConn := tls.Server(mrc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client // Handshake with the local client
if err = tlsConn.Handshake(); err != nil { if err = tlsConn.Handshake(); err != nil {
@ -109,15 +111,17 @@ readLoop:
} }
// maybe it's the others encrypted connection // maybe it's the others encrypted connection
in <- inbound.NewHTTPS(request, mc) in <- inbound.NewHTTPS(session.request, mrc)
} }
// maybe it's a http connection // maybe it's a http connection
goto readLoop goto readLoop
} }
prepareRequest(c, session.request)
// hijack api // hijack api
if getHostnameWithoutPort(session.request) == opt.ApiHost { if session.request.URL.Host == opt.ApiHost {
if err = handleApiRequest(session, opt); err != nil { if err = handleApiRequest(session, opt); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop break readLoop
@ -125,8 +129,6 @@ readLoop:
return return
} }
prepareRequest(c, session.request)
// hijack custom request and write back custom response if necessary // hijack custom request and write back custom response if necessary
if opt.Handler != nil { if opt.Handler != nil {
newReq, newRes := opt.Handler.HandleRequest(session) newReq, newRes := opt.Handler.HandleRequest(session)
@ -144,12 +146,9 @@ readLoop:
} }
} }
httpL.RemoveHopByHopHeaders(session.request.Header)
httpL.RemoveExtraHTTPHostPort(request)
session.request.RequestURI = "" session.request.RequestURI = ""
if session.request.URL.Scheme == "" || session.request.URL.Host == "" { if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(errors.New("invalid URL")) session.response = session.NewErrorResponse(errors.New("invalid URL"))
} else { } else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
@ -162,6 +161,8 @@ readLoop:
session.response = session.NewErrorResponse(err) session.response = session.NewErrorResponse(err)
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
// TODO block unsupported host? // TODO block unsupported host?
_ = writeResponse(session, false)
break readLoop
} }
} }
} }
@ -194,7 +195,7 @@ func writeResponseWithHandler(session *Session, opt *Option) error {
} }
func writeResponse(session *Session, keepAlive bool) error { func writeResponse(session *Session, keepAlive bool) error {
httpL.RemoveHopByHopHeaders(session.response.Header) H.RemoveHopByHopHeaders(session.response.Header)
if keepAlive { if keepAlive {
session.response.Header.Set("Connection", "keep-alive") session.response.Header.Set("Connection", "keep-alive")
@ -226,17 +227,15 @@ func handleApiRequest(session *Session, opt *Option) error {
return session.response.Write(session.conn) return session.response.Write(session.conn)
} }
b := `<!DOCTYPE HTML PUBLIC "- b := `<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
<html> <html><head>
<head> <title>Clash MITM Proxy Services - 404 Not Found</title>
<title>Clash ManInTheMiddle Proxy Services - 404 Not Found</title> </head><body>
</head>
<body>
<h1>Not Found</h1> <h1>Not Found</h1>
<p>The requested URL %s was not found on this server.</p> <p>The requested URL %s was not found on this server.</p>
</body> </body></html>
</html>
` `
if opt.Handler != nil { if opt.Handler != nil {
if opt.Handler.HandleApiRequest(session) { if opt.Handler.HandleApiRequest(session) {
return nil return nil
@ -261,10 +260,7 @@ func handleApiRequest(session *Session, opt *Option) error {
func handleError(opt *Option, session *Session, err error) { func handleError(opt *Option, session *Session, err error) {
if opt.Handler != nil { if opt.Handler != nil {
opt.Handler.HandleError(session, err) opt.Handler.HandleError(session, err)
return
} }
// log.Errorln("[MITM] process mitm error: %v", err)
} }
func prepareRequest(conn net.Conn, request *http.Request) { func prepareRequest(conn net.Conn, request *http.Request) {
@ -277,7 +273,9 @@ func prepareRequest(conn net.Conn, request *http.Request) {
request.URL.Host = request.Host request.URL.Host = request.Host
} }
if request.URL.Scheme == "" {
request.URL.Scheme = "http" request.URL.Scheme = "http"
}
if tlsConn, ok := conn.(*tls.Conn); ok { if tlsConn, ok := conn.(*tls.Conn); ok {
cs := tlsConn.ConnectionState() cs := tlsConn.ConnectionState()
@ -289,6 +287,9 @@ func prepareRequest(conn net.Conn, request *http.Request) {
if request.Header.Get("Accept-Encoding") != "" { if request.Header.Get("Accept-Encoding") != "" {
request.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
} }
H.RemoveHopByHopHeaders(request.Header)
H.RemoveExtraHTTPHostPort(request)
} }
func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool { func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool {
@ -303,19 +304,6 @@ func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool {
return false return false
} }
func getHostnameWithoutPort(req *http.Request) string {
host := req.Host
if host == "" {
host = req.URL.Host
}
if pHost, _, err := net.SplitHostPort(host); err == nil {
host = pHost
}
return host
}
func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr { func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr {
if source != nil { if source != nil {
return source return source

View File

@ -1,16 +1,11 @@
package mitm package mitm
import ( import (
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
C "github.com/Dreamacro/clash/constant"
) )
var serverName = fmt.Sprintf("Clash server (%s)", C.Version)
type Session struct { type Session struct {
conn net.Conn conn net.Conn
request *http.Request request *http.Request
@ -37,16 +32,14 @@ func (s *Session) SetProperties(key string, val any) {
} }
func (s *Session) NewResponse(code int, body io.Reader) *http.Response { func (s *Session) NewResponse(code int, body io.Reader) *http.Response {
res := NewResponse(code, body, s.request) return NewResponse(code, body, s.request)
res.Header.Set("Server", serverName)
return res
} }
func (s *Session) NewErrorResponse(err error) *http.Response { func (s *Session) NewErrorResponse(err error) *http.Response {
return NewErrorResponse(s.request, err) return NewErrorResponse(s.request, err)
} }
func NewSession(conn net.Conn, request *http.Request, response *http.Response) *Session { func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session {
return &Session{ return &Session{
conn: conn, conn: conn,
request: request, request: request,

View File

@ -14,12 +14,12 @@ import (
"golang.org/x/text/transform" "golang.org/x/text/transform"
) )
type MultiReaderConn struct { type multiReaderConn struct {
net.Conn net.Conn
reader io.Reader reader io.Reader
} }
func (c *MultiReaderConn) Read(buf []byte) (int, error) { func (c *multiReaderConn) Read(buf []byte) (int, error) {
return c.reader.Read(buf) return c.reader.Read(buf)
} }
@ -65,7 +65,6 @@ func NewErrorResponse(req *http.Request, err error) *http.Response {
w := fmt.Sprintf(`199 "clash" %q %q`, err.Error(), date) w := fmt.Sprintf(`199 "clash" %q %q`, err.Error(), date)
res.Header.Add("Warning", w) res.Header.Add("Warning", w)
res.Header.Set("Server", serverName)
return res return res
} }

View File

@ -24,9 +24,9 @@ import (
) )
// New TunAdapter // New TunAdapter
func New(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) { func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
var ( var (
tunAddress, _ = netip.ParsePrefix(tunAddressPrefix) tunAddress = netip.Prefix{}
devName = tunConf.Device devName = tunConf.Device
stackType = tunConf.Stack stackType = tunConf.Stack
autoRoute = tunConf.AutoRoute autoRoute = tunConf.AutoRoute
@ -42,6 +42,10 @@ func New(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContex
devName = generateDeviceName() devName = generateDeviceName()
} }
if tunAddressPrefix != nil {
tunAddress = *tunAddressPrefix
}
if !tunAddress.IsValid() || !tunAddress.Addr().Is4() { if !tunAddress.IsValid() || !tunAddress.Addr().Is4() {
tunAddress = netip.MustParsePrefix("198.18.0.1/16") tunAddress = netip.MustParsePrefix("198.18.0.1/16")
} }
@ -144,6 +148,8 @@ func setAtLatest(stackType C.TUNStack, devName string) {
} }
switch runtime.GOOS { switch runtime.GOOS {
case "darwin":
_, _ = cmd.ExecCmd("sysctl net.inet.ip.forwarding=1")
case "windows": case "windows":
_, _ = cmd.ExecCmd("ipconfig /renew") _, _ = cmd.ExecCmd("ipconfig /renew")
case "linux": case "linux":

View File

@ -182,7 +182,7 @@ func preHandleMetadata(metadata *C.Metadata) error {
return nil return nil
} }
func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) { func resolveMetadata(_ C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) {
switch mode { switch mode {
case Direct: case Direct:
proxy = proxies["DIRECT"] proxy = proxies["DIRECT"]
@ -220,7 +220,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
handle := func() bool { handle := func() bool {
pc := natTable.Get(key) pc := natTable.Get(key)
if pc != nil { if pc != nil {
handleUDPToRemote(packet, pc, metadata) _ = handleUDPToRemote(packet, pc, metadata)
return true return true
} }
return false return false
@ -289,7 +289,9 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
} }
func handleTCPConn(connCtx C.ConnContext) { func handleTCPConn(connCtx C.ConnContext) {
defer connCtx.Conn().Close() defer func(conn net.Conn) {
_ = conn.Close()
}(connCtx.Conn())
metadata := connCtx.Metadata() metadata := connCtx.Metadata()
if !metadata.Valid() { if !metadata.Valid() {
@ -307,7 +309,9 @@ func handleTCPConn(connCtx C.ConnContext) {
if MitmOutbound != nil && metadata.Type != C.MITM { if MitmOutbound != nil && metadata.Type != C.MITM {
if remoteConn, err1 := MitmOutbound.DialContext(ctx, metadata); err1 == nil { if remoteConn, err1 := MitmOutbound.DialContext(ctx, metadata); err1 == nil {
remoteConn = statistic.NewSniffing(remoteConn, metadata) remoteConn = statistic.NewSniffing(remoteConn, metadata)
defer remoteConn.Close() defer func(remoteConn C.Conn) {
_ = remoteConn.Close()
}(remoteConn)
handleSocket(connCtx, remoteConn) handleSocket(connCtx, remoteConn)
return return
@ -330,7 +334,9 @@ func handleTCPConn(connCtx C.ConnContext) {
return return
} }
remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule) remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule)
defer remoteConn.Close() defer func(remoteConn C.Conn) {
_ = remoteConn.Close()
}(remoteConn)
switch true { switch true {
case rule != nil: case rule != nil: