Chore: merge branch 'with-tun' into plus-pro

This commit is contained in:
yaling888 2022-04-12 00:47:45 +08:00
commit 571c34f140
17 changed files with 323 additions and 261 deletions

View File

@ -1,7 +1,7 @@
package fakeip
import (
"net"
"net/netip"
"github.com/Dreamacro/clash/component/profile/cachefile"
)
@ -11,22 +11,27 @@ type cachefileStore struct {
}
// GetByHost implements store.GetByHost
func (c *cachefileStore) GetByHost(host string) (net.IP, bool) {
func (c *cachefileStore) GetByHost(host string) (netip.Addr, bool) {
elm := c.cache.GetFakeip([]byte(host))
if elm == nil {
return nil, false
return netip.Addr{}, false
}
if len(elm) == 4 {
return netip.AddrFrom4(*(*[4]byte)(elm)), true
} else {
return netip.AddrFrom16(*(*[16]byte)(elm)), true
}
return net.IP(elm), true
}
// PutByHost implements store.PutByHost
func (c *cachefileStore) PutByHost(host string, ip net.IP) {
c.cache.PutFakeip([]byte(host), ip)
func (c *cachefileStore) PutByHost(host string, ip netip.Addr) {
c.cache.PutFakeip([]byte(host), ip.AsSlice())
}
// GetByIP implements store.GetByIP
func (c *cachefileStore) GetByIP(ip net.IP) (string, bool) {
elm := c.cache.GetFakeip(ip.To4())
func (c *cachefileStore) GetByIP(ip netip.Addr) (string, bool) {
elm := c.cache.GetFakeip(ip.AsSlice())
if elm == nil {
return "", false
}
@ -34,18 +39,18 @@ func (c *cachefileStore) GetByIP(ip net.IP) (string, bool) {
}
// PutByIP implements store.PutByIP
func (c *cachefileStore) PutByIP(ip net.IP, host string) {
c.cache.PutFakeip(ip.To4(), []byte(host))
func (c *cachefileStore) PutByIP(ip netip.Addr, host string) {
c.cache.PutFakeip(ip.AsSlice(), []byte(host))
}
// DelByIP implements store.DelByIP
func (c *cachefileStore) DelByIP(ip net.IP) {
ip = ip.To4()
c.cache.DelFakeipPair(ip, c.cache.GetFakeip(ip.To4()))
func (c *cachefileStore) DelByIP(ip netip.Addr) {
addr := ip.AsSlice()
c.cache.DelFakeipPair(addr, c.cache.GetFakeip(addr))
}
// Exist implements store.Exist
func (c *cachefileStore) Exist(ip net.IP) bool {
func (c *cachefileStore) Exist(ip netip.Addr) bool {
_, exist := c.GetByIP(ip)
return exist
}

View File

@ -1,35 +1,35 @@
package fakeip
import (
"net"
"net/netip"
"github.com/Dreamacro/clash/common/cache"
)
type memoryStore struct {
cacheIP *cache.LruCache[string, net.IP]
cacheHost *cache.LruCache[uint32, string]
cacheIP *cache.LruCache[string, netip.Addr]
cacheHost *cache.LruCache[netip.Addr, string]
}
// GetByHost implements store.GetByHost
func (m *memoryStore) GetByHost(host string) (net.IP, bool) {
func (m *memoryStore) GetByHost(host string) (netip.Addr, bool) {
if ip, exist := m.cacheIP.Get(host); exist {
// ensure ip --> host on head of linked list
m.cacheHost.Get(ipToUint(ip.To4()))
m.cacheHost.Get(ip)
return ip, true
}
return nil, false
return netip.Addr{}, false
}
// PutByHost implements store.PutByHost
func (m *memoryStore) PutByHost(host string, ip net.IP) {
func (m *memoryStore) PutByHost(host string, ip netip.Addr) {
m.cacheIP.Set(host, ip)
}
// GetByIP implements store.GetByIP
func (m *memoryStore) GetByIP(ip net.IP) (string, bool) {
if host, exist := m.cacheHost.Get(ipToUint(ip.To4())); exist {
func (m *memoryStore) GetByIP(ip netip.Addr) (string, bool) {
if host, exist := m.cacheHost.Get(ip); exist {
// ensure host --> ip on head of linked list
m.cacheIP.Get(host)
return host, true
@ -39,22 +39,21 @@ func (m *memoryStore) GetByIP(ip net.IP) (string, bool) {
}
// PutByIP implements store.PutByIP
func (m *memoryStore) PutByIP(ip net.IP, host string) {
m.cacheHost.Set(ipToUint(ip.To4()), host)
func (m *memoryStore) PutByIP(ip netip.Addr, host string) {
m.cacheHost.Set(ip, host)
}
// DelByIP implements store.DelByIP
func (m *memoryStore) DelByIP(ip net.IP) {
ipNum := ipToUint(ip.To4())
if host, exist := m.cacheHost.Get(ipNum); exist {
func (m *memoryStore) DelByIP(ip netip.Addr) {
if host, exist := m.cacheHost.Get(ip); exist {
m.cacheIP.Delete(host)
}
m.cacheHost.Delete(ipNum)
m.cacheHost.Delete(ip)
}
// Exist implements store.Exist
func (m *memoryStore) Exist(ip net.IP) bool {
return m.cacheHost.Exist(ipToUint(ip.To4()))
func (m *memoryStore) Exist(ip netip.Addr) bool {
return m.cacheHost.Exist(ip)
}
// CloneTo implements store.CloneTo
@ -74,7 +73,7 @@ func (m *memoryStore) FlushFakeIP() error {
func newMemoryStore(size int) *memoryStore {
return &memoryStore{
cacheIP: cache.NewLRUCache[string, net.IP](cache.WithSize[string, net.IP](size)),
cacheHost: cache.NewLRUCache[uint32, string](cache.WithSize[uint32, string](size)),
cacheIP: cache.NewLRUCache[string, netip.Addr](cache.WithSize[string, netip.Addr](size)),
cacheHost: cache.NewLRUCache[netip.Addr, string](cache.WithSize[netip.Addr, string](size)),
}
}

View File

@ -2,39 +2,52 @@ package fakeip
import (
"errors"
"net"
"math/bits"
"net/netip"
"sync"
_ "unsafe"
"github.com/Dreamacro/clash/component/profile/cachefile"
"github.com/Dreamacro/clash/component/trie"
)
//go:linkname beUint64 net/netip.beUint64
func beUint64(b []byte) uint64
//go:linkname bePutUint64 net/netip.bePutUint64
func bePutUint64(b []byte, v uint64)
type uint128 struct {
hi uint64
lo uint64
}
type store interface {
GetByHost(host string) (net.IP, bool)
PutByHost(host string, ip net.IP)
GetByIP(ip net.IP) (string, bool)
PutByIP(ip net.IP, host string)
DelByIP(ip net.IP)
Exist(ip net.IP) bool
GetByHost(host string) (netip.Addr, bool)
PutByHost(host string, ip netip.Addr)
GetByIP(ip netip.Addr) (string, bool)
PutByIP(ip netip.Addr, host string)
DelByIP(ip netip.Addr)
Exist(ip netip.Addr) bool
CloneTo(store)
FlushFakeIP() error
}
// Pool is a implementation about fake ip generator without storage
type Pool struct {
max uint32
min uint32
gateway uint32
broadcast uint32
offset uint32
gateway netip.Addr
first netip.Addr
last netip.Addr
offset netip.Addr
cycle bool
mux sync.Mutex
host *trie.DomainTrie[bool]
ipnet *net.IPNet
ipnet *netip.Prefix
store store
}
// Lookup return a fake ip with host
func (p *Pool) Lookup(host string) net.IP {
func (p *Pool) Lookup(host string) netip.Addr {
p.mux.Lock()
defer p.mux.Unlock()
if ip, exist := p.store.GetByHost(host); exist {
@ -47,14 +60,10 @@ func (p *Pool) Lookup(host string) net.IP {
}
// LookBack return host with the fake ip
func (p *Pool) LookBack(ip net.IP) (string, bool) {
func (p *Pool) LookBack(ip netip.Addr) (string, bool) {
p.mux.Lock()
defer p.mux.Unlock()
if ip = ip.To4(); ip == nil {
return "", false
}
return p.store.GetByIP(ip)
}
@ -67,29 +76,25 @@ func (p *Pool) ShouldSkipped(domain string) bool {
}
// Exist returns if given ip exists in fake-ip pool
func (p *Pool) Exist(ip net.IP) bool {
func (p *Pool) Exist(ip netip.Addr) bool {
p.mux.Lock()
defer p.mux.Unlock()
if ip = ip.To4(); ip == nil {
return false
}
return p.store.Exist(ip)
}
// Gateway return gateway ip
func (p *Pool) Gateway() net.IP {
return uintToIP(p.gateway)
func (p *Pool) Gateway() netip.Addr {
return p.gateway
}
// Broadcast return broadcast ip
func (p *Pool) Broadcast() net.IP {
return uintToIP(p.broadcast)
// Broadcast return the last ip
func (p *Pool) Broadcast() netip.Addr {
return p.last
}
// IPNet return raw ipnet
func (p *Pool) IPNet() *net.IPNet {
func (p *Pool) IPNet() *netip.Prefix {
return p.ipnet
}
@ -98,46 +103,28 @@ func (p *Pool) CloneFrom(o *Pool) {
o.store.CloneTo(p.store)
}
func (p *Pool) get(host string) net.IP {
current := p.offset
for {
p.offset = (p.offset + 1) % (p.max - p.min)
// Avoid infinite loops
if p.offset == current {
p.offset = (p.offset + 1) % (p.max - p.min)
ip := uintToIP(p.min + p.offset - 1)
p.store.DelByIP(ip)
break
func (p *Pool) get(host string) netip.Addr {
p.offset = p.offset.Next()
if !p.offset.Less(p.last) {
p.cycle = true
p.offset = p.first
}
ip := uintToIP(p.min + p.offset - 1)
if !p.store.Exist(ip) {
break
if p.cycle {
p.store.DelByIP(p.offset)
}
}
ip := uintToIP(p.min + p.offset - 1)
p.store.PutByIP(ip, host)
return ip
p.store.PutByIP(p.offset, host)
return p.offset
}
func (p *Pool) FlushFakeIP() error {
return p.store.FlushFakeIP()
}
func ipToUint(ip net.IP) uint32 {
v := uint32(ip[0]) << 24
v += uint32(ip[1]) << 16
v += uint32(ip[2]) << 8
v += uint32(ip[3])
return v
}
func uintToIP(v uint32) net.IP {
return net.IP{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
type Options struct {
IPNet *net.IPNet
IPNet *netip.Prefix
Host *trie.DomainTrie[bool]
// Size sets the maximum number of entries in memory
@ -151,21 +138,23 @@ type Options struct {
// New return Pool instance
func New(options Options) (*Pool, error) {
min := ipToUint(options.IPNet.IP) + 3
var (
hostAddr = options.IPNet.Masked().Addr()
gateway = hostAddr.Next()
first = gateway.Next().Next()
last = add(hostAddr, 1<<uint64(hostAddr.BitLen()-options.IPNet.Bits())-1)
)
ones, bits := options.IPNet.Mask.Size()
total := 1<<uint(bits-ones) - 4
if total <= 0 {
if !options.IPNet.IsValid() || !first.Less(last) || !options.IPNet.Contains(last) {
return nil, errors.New("ipnet don't have valid ip")
}
max := min + uint32(total) - 1
pool := &Pool{
min: min,
max: max,
gateway: min - 2,
broadcast: max + 1,
gateway: gateway,
first: first,
last: last,
offset: first.Prev(),
cycle: false,
host: options.Host,
ipnet: options.IPNet,
}
@ -179,3 +168,29 @@ func New(options Options) (*Pool, error) {
return pool, nil
}
// add returns addr + n.
func add(addr netip.Addr, n uint64) netip.Addr {
buf := addr.As16()
u := uint128{
beUint64(buf[:8]),
beUint64(buf[8:]),
}
lo, carry := bits.Add64(u.lo, n, 0)
u.hi = u.hi + carry
u.lo = lo
bePutUint64(buf[:8], u.hi)
bePutUint64(buf[8:], u.lo)
a := netip.AddrFrom16(buf)
if addr.Is4() {
return a.Unmap()
}
return a
}

View File

@ -2,7 +2,7 @@ package fakeip
import (
"fmt"
"net"
"net/netip"
"os"
"testing"
"time"
@ -49,9 +49,9 @@ func createCachefileStore(options Options) (*Pool, string, error) {
}
func TestPool_Basic(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.0/28")
ipnet := netip.MustParsePrefix("192.168.0.0/28")
pools, tempfile, err := createPools(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 10,
})
assert.Nil(t, err)
@ -62,24 +62,52 @@ func TestPool_Basic(t *testing.T) {
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first.Equal(net.IP{192, 168, 0, 3}))
assert.Equal(t, pool.Lookup("foo.com"), net.IP{192, 168, 0, 3})
assert.True(t, last.Equal(net.IP{192, 168, 0, 4}))
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, pool.Lookup("foo.com") == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.Equal(t, pool.Gateway(), net.IP{192, 168, 0, 1})
assert.Equal(t, pool.Broadcast(), net.IP{192, 168, 0, 15})
assert.True(t, pool.Gateway() == netip.AddrFrom4([4]byte{192, 168, 0, 1}))
assert.True(t, pool.Broadcast() == netip.AddrFrom4([4]byte{192, 168, 0, 15}))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(net.IP{192, 168, 0, 4}))
assert.False(t, pool.Exist(net.IP{192, 168, 0, 5}))
assert.False(t, pool.Exist(net.ParseIP("::1")))
assert.True(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 4})))
assert.False(t, pool.Exist(netip.AddrFrom4([4]byte{192, 168, 0, 5})))
assert.False(t, pool.Exist(netip.MustParseAddr("::1")))
}
}
func TestPool_BasicV6(t *testing.T) {
ipnet := netip.MustParsePrefix("2001:4860:4860::8888/118")
pools, tempfile, err := createPools(Options{
IPNet: &ipnet,
Size: 10,
})
assert.Nil(t, err)
defer os.Remove(tempfile)
for _, pool := range pools {
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
assert.True(t, first == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8803"))
assert.True(t, pool.Lookup("foo.com") == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8803"))
assert.True(t, last == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804"))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
assert.True(t, pool.Gateway() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8801"))
assert.True(t, pool.Broadcast() == netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8bff"))
assert.Equal(t, pool.IPNet().String(), ipnet.String())
assert.True(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8804")))
assert.False(t, pool.Exist(netip.MustParseAddr("2001:4860:4860:0000:0000:0000:0000:8805")))
assert.False(t, pool.Exist(netip.MustParseAddr("127.0.0.1")))
}
}
func TestPool_CycleUsed(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.16/28")
ipnet := netip.MustParsePrefix("192.168.0.16/28")
pools, tempfile, err := createPools(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 10,
})
assert.Nil(t, err)
@ -88,22 +116,22 @@ func TestPool_CycleUsed(t *testing.T) {
for _, pool := range pools {
foo := pool.Lookup("foo.com")
bar := pool.Lookup("bar.com")
for i := 0; i < 9; i++ {
for i := 0; i < 10; i++ {
pool.Lookup(fmt.Sprintf("%d.com", i))
}
baz := pool.Lookup("baz.com")
next := pool.Lookup("foo.com")
assert.True(t, foo.Equal(baz))
assert.True(t, next.Equal(bar))
assert.True(t, foo == baz)
assert.True(t, next == bar)
}
}
func TestPool_Skip(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/29")
ipnet := netip.MustParsePrefix("192.168.0.1/29")
tree := trie.New[bool]()
tree.Insert("example.com", true)
pools, tempfile, err := createPools(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 10,
Host: tree,
})
@ -117,9 +145,9 @@ func TestPool_Skip(t *testing.T) {
}
func TestPool_MaxCacheSize(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24")
ipnet := netip.MustParsePrefix("192.168.0.1/24")
pool, _ := New(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 2,
})
@ -128,13 +156,13 @@ func TestPool_MaxCacheSize(t *testing.T) {
pool.Lookup("baz.com")
next := pool.Lookup("foo.com")
assert.False(t, first.Equal(next))
assert.False(t, first == next)
}
func TestPool_DoubleMapping(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24")
ipnet := netip.MustParsePrefix("192.168.0.1/24")
pool, _ := New(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 2,
})
@ -158,23 +186,23 @@ func TestPool_DoubleMapping(t *testing.T) {
assert.False(t, bazExist)
assert.True(t, barExist)
assert.False(t, bazIP.Equal(newBazIP))
assert.False(t, bazIP == newBazIP)
}
func TestPool_Clone(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24")
ipnet := netip.MustParsePrefix("192.168.0.1/24")
pool, _ := New(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 2,
})
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
assert.True(t, first.Equal(net.IP{192, 168, 0, 3}))
assert.True(t, last.Equal(net.IP{192, 168, 0, 4}))
assert.True(t, first == netip.AddrFrom4([4]byte{192, 168, 0, 3}))
assert.True(t, last == netip.AddrFrom4([4]byte{192, 168, 0, 4}))
newPool, _ := New(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 2,
})
newPool.CloneFrom(pool)
@ -185,9 +213,9 @@ func TestPool_Clone(t *testing.T) {
}
func TestPool_Error(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/31")
ipnet := netip.MustParsePrefix("192.168.0.1/31")
_, err := New(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 10,
})
@ -195,9 +223,9 @@ func TestPool_Error(t *testing.T) {
}
func TestPool_FlushFileCache(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/28")
ipnet := netip.MustParsePrefix("192.168.0.1/28")
pools, tempfile, err := createPools(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 10,
})
assert.Nil(t, err)
@ -216,18 +244,18 @@ func TestPool_FlushFileCache(t *testing.T) {
next := pool.Lookup("baz.com")
nero := pool.Lookup("foo.com")
assert.Equal(t, foo, fox)
assert.NotEqual(t, foo, baz)
assert.Equal(t, bar, bax)
assert.NotEqual(t, bar, next)
assert.Equal(t, baz, nero)
assert.True(t, foo == fox)
assert.False(t, foo == baz)
assert.True(t, bar == bax)
assert.False(t, bar == next)
assert.True(t, baz == nero)
}
}
func TestPool_FlushMemoryCache(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/28")
ipnet := netip.MustParsePrefix("192.168.0.1/28")
pool, _ := New(Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 10,
})
@ -243,9 +271,9 @@ func TestPool_FlushMemoryCache(t *testing.T) {
next := pool.Lookup("baz.com")
nero := pool.Lookup("foo.com")
assert.Equal(t, foo, fox)
assert.NotEqual(t, foo, baz)
assert.Equal(t, bar, bax)
assert.NotEqual(t, bar, next)
assert.Equal(t, baz, nero)
assert.True(t, foo == fox)
assert.False(t, foo == baz)
assert.True(t, bar == bax)
assert.False(t, bar == next)
assert.True(t, baz == nero)
}

View File

@ -592,7 +592,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) {
}
// add mitm.clash hosts
if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{8, 8, 9, 9})); err != nil {
if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{1, 2, 3, 4})); err != nil {
log.Errorln("insert mitm.clash to host error: %s", err.Error())
}
@ -777,7 +777,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
}
if cfg.EnhancedMode == C.DNSFakeIP {
_, ipnet, err := net.ParseCIDR(cfg.FakeIPRange)
ipnet, err := netip.ParsePrefix(cfg.FakeIPRange)
if err != nil {
return nil, err
}
@ -804,7 +804,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R
}
pool, err := fakeip.New(fakeip.Options{
IPNet: ipnet,
IPNet: &ipnet,
Size: 1000,
Host: host,
Persistence: rawCfg.Profile.StoreFakeIP,

View File

@ -2,6 +2,7 @@ package dns
import (
"net"
"net/netip"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
@ -11,7 +12,7 @@ import (
type ResolverEnhancer struct {
mode C.DNSMode
fakePool *fakeip.Pool
mapping *cache.LruCache[string, string]
mapping *cache.LruCache[netip.Addr, string]
}
func (h *ResolverEnhancer) FakeIPEnabled() bool {
@ -28,7 +29,7 @@ func (h *ResolverEnhancer) IsExistFakeIP(ip net.IP) bool {
}
if pool := h.fakePool; pool != nil {
return pool.Exist(ip)
return pool.Exist(ipToAddr(ip))
}
return false
@ -39,8 +40,10 @@ func (h *ResolverEnhancer) IsFakeIP(ip net.IP) bool {
return false
}
addr := ipToAddr(ip)
if pool := h.fakePool; pool != nil {
return pool.IPNet().Contains(ip) && !pool.Gateway().Equal(ip) && !pool.Broadcast().Equal(ip)
return pool.IPNet().Contains(addr) && addr != pool.Gateway() && addr != pool.Broadcast()
}
return false
@ -52,21 +55,22 @@ func (h *ResolverEnhancer) IsFakeBroadcastIP(ip net.IP) bool {
}
if pool := h.fakePool; pool != nil {
return pool.Broadcast().Equal(ip)
return pool.Broadcast() == ipToAddr(ip)
}
return false
}
func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) {
addr := ipToAddr(ip)
if pool := h.fakePool; pool != nil {
if host, existed := pool.LookBack(ip); existed {
if host, existed := pool.LookBack(addr); existed {
return host, true
}
}
if mapping := h.mapping; mapping != nil {
if host, existed := h.mapping.Get(ip.String()); existed {
if host, existed := h.mapping.Get(addr); existed {
return host, true
}
}
@ -76,7 +80,7 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) {
func (h *ResolverEnhancer) InsertHostByIP(ip net.IP, host string) {
if mapping := h.mapping; mapping != nil {
h.mapping.Set(ip.String(), host)
h.mapping.Set(ipToAddr(ip), host)
}
}
@ -99,11 +103,11 @@ func (h *ResolverEnhancer) FlushFakeIP() error {
func NewEnhancer(cfg Config) *ResolverEnhancer {
var fakePool *fakeip.Pool
var mapping *cache.LruCache[string, string]
var mapping *cache.LruCache[netip.Addr, string]
if cfg.EnhancedMode != C.DNSNormal {
fakePool = cfg.Pool
mapping = cache.NewLRUCache[string, string](cache.WithSize[string, string](4096), cache.WithStale[string, string](true))
mapping = cache.NewLRUCache[netip.Addr, string](cache.WithSize[netip.Addr, string](4096), cache.WithStale[netip.Addr, string](true))
}
return &ResolverEnhancer{

View File

@ -21,7 +21,7 @@ type (
middleware func(next handler) handler
)
func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[string, string]) middleware {
func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
@ -30,28 +30,25 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin
return next(ctx, r)
}
qName := strings.TrimRight(q.Name, ".")
record := hosts.Search(qName)
host := strings.TrimRight(q.Name, ".")
record := hosts.Search(host)
if record == nil {
return next(ctx, r)
}
ip := record.Data
if mapping != nil {
mapping.SetWithExpire(ip.Unmap().String(), qName, time.Now().Add(time.Second*5))
}
msg := r.Copy()
if ip.Is4() && q.Qtype == D.TypeA {
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 1}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
rr.A = ip.AsSlice()
msg.Answer = []D.RR{rr}
} else if ip.Is6() && q.Qtype == D.TypeAAAA {
rr := &D.AAAA{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 1}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
rr.AAAA = ip.AsSlice()
msg.Answer = []D.RR{rr}
@ -59,6 +56,10 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin
return next(ctx, r)
}
if mapping != nil {
mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*10))
}
ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
@ -69,7 +70,7 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin
}
}
func withMapping(mapping *cache.LruCache[string, string]) middleware {
func withMapping(mapping *cache.LruCache[netip.Addr, string]) middleware {
return func(next handler) handler {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
@ -100,7 +101,7 @@ func withMapping(mapping *cache.LruCache[string, string]) middleware {
continue
}
mapping.SetWithExpire(ip.String(), host, time.Now().Add(time.Second*time.Duration(ttl)))
mapping.SetWithExpire(ipToAddr(ip), host, time.Now().Add(time.Second*time.Duration(ttl)))
}
return msg, nil
@ -130,7 +131,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := fakePool.Lookup(host)
rr.A = ip
rr.A = ip.AsSlice()
msg := r.Copy()
msg.Answer = []D.RR{rr}

View File

@ -5,6 +5,7 @@ import (
"crypto/tls"
"fmt"
"net"
"net/netip"
"time"
"github.com/Dreamacro/clash/common/cache"
@ -111,6 +112,22 @@ func msgToIP(msg *D.Msg) []net.IP {
return ips
}
func ipToAddr(ip net.IP) netip.Addr {
if ip == nil {
return netip.Addr{}
}
l := len(ip)
if l == 4 {
return netip.AddrFrom4(*(*[4]byte)(ip))
} else if l == 16 {
return netip.AddrFrom16(*(*[16]byte)(ip))
} else {
return netip.Addr{}
}
}
type wrapPacketConn struct {
net.PacketConn
rAddr net.Addr

View File

@ -195,9 +195,9 @@ func updateRuleProviders(providers map[string]C.Rule) {
}
func updateTun(tun *config.Tun, dns *config.DNS) {
var tunAddressPrefix string
var tunAddressPrefix *netip.Prefix
if dns.FakeIPRange != nil {
tunAddressPrefix = dns.FakeIPRange.IPNet().String()
tunAddressPrefix = dns.FakeIPRange.IPNet()
}
P.ReCreateTun(tun, tunAddressPrefix, tunnel.TCPIn(), tunnel.UDPIn())

View File

@ -70,11 +70,11 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
RemoveExtraHTTPHostPort(request)
if request.URL.Scheme == "" || request.URL.Host == "" {
resp = ResponseWith(request, http.StatusBadRequest)
resp = responseWith(request, http.StatusBadRequest)
} else {
resp, err = client.Do(request)
if err != nil {
resp = ResponseWith(request, http.StatusBadGateway)
resp = responseWith(request, http.StatusBadGateway)
}
}
@ -103,7 +103,7 @@ func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http
if authenticator != nil {
credential := parseBasicProxyAuthorization(request)
if credential == "" {
resp := ResponseWith(request, http.StatusProxyAuthRequired)
resp := responseWith(request, http.StatusProxyAuthRequired)
resp.Header.Set("Proxy-Authenticate", "Basic")
return resp
}
@ -117,14 +117,14 @@ func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http
if !authed {
log.Infoln("Auth failed from %s", request.RemoteAddr)
return ResponseWith(request, http.StatusForbidden)
return responseWith(request, http.StatusForbidden)
}
}
return nil
}
func ResponseWith(request *http.Request, statusCode int) *http.Response {
func responseWith(request *http.Request, statusCode int) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),

View File

@ -40,7 +40,7 @@ func RemoveExtraHTTPHostPort(req *http.Request) {
host = req.URL.Host
}
if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
if pHost, port, err := net.SplitHostPort(host); err == nil && (port == "80" || port == "443") {
host = pHost
}

View File

@ -6,6 +6,7 @@ import (
"crypto/x509"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"sync"
@ -319,7 +320,7 @@ func ReCreateMixed(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.P
log.Infoln("Mixed(http+socks) proxy listening at: %s", mixedListener.Address())
}
func ReCreateTun(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) {
func ReCreateTun(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) {
tunMux.Lock()
defer tunMux.Unlock()

View File

@ -17,7 +17,7 @@ import (
"github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant"
httpL "github.com/Dreamacro/clash/listener/http"
H "github.com/Dreamacro/clash/listener/http"
)
func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
@ -48,9 +48,12 @@ startOver:
readLoop:
for {
_ = conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive
err := conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive
if err != nil {
break readLoop
}
request, err := httpL.ReadRequest(conn.Reader())
request, err := H.ReadRequest(conn.Reader())
if err != nil {
handleError(opt, nil, err)
break readLoop
@ -58,15 +61,15 @@ readLoop:
var response *http.Response
session := NewSession(conn, request, response)
session := newSession(conn, request, response)
source = parseSourceAddress(session.request, c, source)
request.RemoteAddr = source.String()
session.request.RemoteAddr = source.String()
if !trusted {
response = httpL.Authenticate(request, cache)
session.response = H.Authenticate(session.request, cache)
trusted = response == nil
trusted = session.response == nil
}
if trusted {
@ -84,19 +87,18 @@ readLoop:
break readLoop // close connection
}
buf := make([]byte, session.conn.(*N.BufferedConn).Buffered())
_, _ = session.conn.Read(buf)
buff := make([]byte, session.conn.(*N.BufferedConn).Buffered())
_, _ = session.conn.Read(buff)
mc := &MultiReaderConn{
mrc := &multiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), session.conn),
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
}
// 22 is the TLS handshake.
// https://tools.ietf.org/html/rfc5246#section-6.2.1
if b[0] == 22 {
// TLS handshake.
if b[0] == 0x16 {
// TODO serve by generic host name maybe better?
tlsConn := tls.Server(mc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
tlsConn := tls.Server(mrc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client
if err = tlsConn.Handshake(); err != nil {
@ -109,15 +111,17 @@ readLoop:
}
// maybe it's the others encrypted connection
in <- inbound.NewHTTPS(request, mc)
in <- inbound.NewHTTPS(session.request, mrc)
}
// maybe it's a http connection
goto readLoop
}
prepareRequest(c, session.request)
// hijack api
if getHostnameWithoutPort(session.request) == opt.ApiHost {
if session.request.URL.Host == opt.ApiHost {
if err = handleApiRequest(session, opt); err != nil {
handleError(opt, session, err)
break readLoop
@ -125,8 +129,6 @@ readLoop:
return
}
prepareRequest(c, session.request)
// hijack custom request and write back custom response if necessary
if opt.Handler != nil {
newReq, newRes := opt.Handler.HandleRequest(session)
@ -144,12 +146,9 @@ readLoop:
}
}
httpL.RemoveHopByHopHeaders(session.request.Header)
httpL.RemoveExtraHTTPHostPort(request)
session.request.RequestURI = ""
if session.request.URL.Scheme == "" || session.request.URL.Host == "" {
if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(errors.New("invalid URL"))
} else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
@ -162,6 +161,8 @@ readLoop:
session.response = session.NewErrorResponse(err)
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
// TODO block unsupported host?
_ = writeResponse(session, false)
break readLoop
}
}
}
@ -194,7 +195,7 @@ func writeResponseWithHandler(session *Session, opt *Option) error {
}
func writeResponse(session *Session, keepAlive bool) error {
httpL.RemoveHopByHopHeaders(session.response.Header)
H.RemoveHopByHopHeaders(session.response.Header)
if keepAlive {
session.response.Header.Set("Connection", "keep-alive")
@ -226,17 +227,15 @@ func handleApiRequest(session *Session, opt *Option) error {
return session.response.Write(session.conn)
}
b := `<!DOCTYPE HTML PUBLIC "-
<html>
<head>
<title>Clash ManInTheMiddle Proxy Services - 404 Not Found</title>
</head>
<body>
b := `<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
<html><head>
<title>Clash MITM Proxy Services - 404 Not Found</title>
</head><body>
<h1>Not Found</h1>
<p>The requested URL %s was not found on this server.</p>
</body>
</html>
</body></html>
`
if opt.Handler != nil {
if opt.Handler.HandleApiRequest(session) {
return nil
@ -261,10 +260,7 @@ func handleApiRequest(session *Session, opt *Option) error {
func handleError(opt *Option, session *Session, err error) {
if opt.Handler != nil {
opt.Handler.HandleError(session, err)
return
}
// log.Errorln("[MITM] process mitm error: %v", err)
}
func prepareRequest(conn net.Conn, request *http.Request) {
@ -277,7 +273,9 @@ func prepareRequest(conn net.Conn, request *http.Request) {
request.URL.Host = request.Host
}
if request.URL.Scheme == "" {
request.URL.Scheme = "http"
}
if tlsConn, ok := conn.(*tls.Conn); ok {
cs := tlsConn.ConnectionState()
@ -289,6 +287,9 @@ func prepareRequest(conn net.Conn, request *http.Request) {
if request.Header.Get("Accept-Encoding") != "" {
request.Header.Set("Accept-Encoding", "gzip")
}
H.RemoveHopByHopHeaders(request.Header)
H.RemoveExtraHTTPHostPort(request)
}
func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool {
@ -303,19 +304,6 @@ func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool {
return false
}
func getHostnameWithoutPort(req *http.Request) string {
host := req.Host
if host == "" {
host = req.URL.Host
}
if pHost, _, err := net.SplitHostPort(host); err == nil {
host = pHost
}
return host
}
func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr {
if source != nil {
return source

View File

@ -1,16 +1,11 @@
package mitm
import (
"fmt"
"io"
"net"
"net/http"
C "github.com/Dreamacro/clash/constant"
)
var serverName = fmt.Sprintf("Clash server (%s)", C.Version)
type Session struct {
conn net.Conn
request *http.Request
@ -37,16 +32,14 @@ func (s *Session) SetProperties(key string, val any) {
}
func (s *Session) NewResponse(code int, body io.Reader) *http.Response {
res := NewResponse(code, body, s.request)
res.Header.Set("Server", serverName)
return res
return NewResponse(code, body, s.request)
}
func (s *Session) NewErrorResponse(err error) *http.Response {
return NewErrorResponse(s.request, err)
}
func NewSession(conn net.Conn, request *http.Request, response *http.Response) *Session {
func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session {
return &Session{
conn: conn,
request: request,

View File

@ -14,12 +14,12 @@ import (
"golang.org/x/text/transform"
)
type MultiReaderConn struct {
type multiReaderConn struct {
net.Conn
reader io.Reader
}
func (c *MultiReaderConn) Read(buf []byte) (int, error) {
func (c *multiReaderConn) Read(buf []byte) (int, error) {
return c.reader.Read(buf)
}
@ -65,7 +65,6 @@ func NewErrorResponse(req *http.Request, err error) *http.Response {
w := fmt.Sprintf(`199 "clash" %q %q`, err.Error(), date)
res.Header.Add("Warning", w)
res.Header.Set("Server", serverName)
return res
}

View File

@ -24,9 +24,9 @@ import (
)
// New TunAdapter
func New(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
func New(tunConf *config.Tun, tunAddressPrefix *netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
var (
tunAddress, _ = netip.ParsePrefix(tunAddressPrefix)
tunAddress = netip.Prefix{}
devName = tunConf.Device
stackType = tunConf.Stack
autoRoute = tunConf.AutoRoute
@ -42,6 +42,10 @@ func New(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.ConnContex
devName = generateDeviceName()
}
if tunAddressPrefix != nil {
tunAddress = *tunAddressPrefix
}
if !tunAddress.IsValid() || !tunAddress.Addr().Is4() {
tunAddress = netip.MustParsePrefix("198.18.0.1/16")
}
@ -144,6 +148,8 @@ func setAtLatest(stackType C.TUNStack, devName string) {
}
switch runtime.GOOS {
case "darwin":
_, _ = cmd.ExecCmd("sysctl net.inet.ip.forwarding=1")
case "windows":
_, _ = cmd.ExecCmd("ipconfig /renew")
case "linux":

View File

@ -182,7 +182,7 @@ func preHandleMetadata(metadata *C.Metadata) error {
return nil
}
func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) {
func resolveMetadata(_ C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) {
switch mode {
case Direct:
proxy = proxies["DIRECT"]
@ -220,7 +220,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
handle := func() bool {
pc := natTable.Get(key)
if pc != nil {
handleUDPToRemote(packet, pc, metadata)
_ = handleUDPToRemote(packet, pc, metadata)
return true
}
return false
@ -289,7 +289,9 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
}
func handleTCPConn(connCtx C.ConnContext) {
defer connCtx.Conn().Close()
defer func(conn net.Conn) {
_ = conn.Close()
}(connCtx.Conn())
metadata := connCtx.Metadata()
if !metadata.Valid() {
@ -307,7 +309,9 @@ func handleTCPConn(connCtx C.ConnContext) {
if MitmOutbound != nil && metadata.Type != C.MITM {
if remoteConn, err1 := MitmOutbound.DialContext(ctx, metadata); err1 == nil {
remoteConn = statistic.NewSniffing(remoteConn, metadata)
defer remoteConn.Close()
defer func(remoteConn C.Conn) {
_ = remoteConn.Close()
}(remoteConn)
handleSocket(connCtx, remoteConn)
return
@ -330,7 +334,9 @@ func handleTCPConn(connCtx C.ConnContext) {
return
}
remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule)
defer remoteConn.Close()
defer func(remoteConn C.Conn) {
_ = remoteConn.Close()
}(remoteConn)
switch true {
case rule != nil: