From 5999b6262d9ed6504effe2549e1add7769064bc7 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Mon, 11 Apr 2022 06:28:42 +0800 Subject: [PATCH 1/2] Chore: fix typos --- config/config.go | 2 +- dns/middleware.go | 17 ++++---- listener/http/proxy.go | 10 ++--- listener/http/utils.go | 2 +- listener/mitm/proxy.go | 90 +++++++++++++++++----------------------- listener/mitm/session.go | 11 +---- listener/mitm/utils.go | 5 +-- tunnel/tunnel.go | 16 ++++--- 8 files changed, 70 insertions(+), 83 deletions(-) diff --git a/config/config.go b/config/config.go index e86a5e42..12fac706 100644 --- a/config/config.go +++ b/config/config.go @@ -547,7 +547,7 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { } // add mitm.clash hosts - if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{8, 8, 9, 9})); err != nil { + if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{1, 2, 3, 4})); err != nil { log.Errorln("insert mitm.clash to host error: %s", err.Error()) } diff --git a/dns/middleware.go b/dns/middleware.go index 4091fa9e..eeb40c1d 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -30,28 +30,25 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin return next(ctx, r) } - qName := strings.TrimRight(q.Name, ".") - record := hosts.Search(qName) + host := strings.TrimRight(q.Name, ".") + + record := hosts.Search(host) if record == nil { return next(ctx, r) } ip := record.Data - if mapping != nil { - mapping.SetWithExpire(ip.Unmap().String(), qName, time.Now().Add(time.Second*5)) - } - msg := r.Copy() if ip.Is4() && q.Qtype == D.TypeA { rr := &D.A{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 1} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} rr.A = ip.AsSlice() msg.Answer = []D.RR{rr} } else if ip.Is6() && q.Qtype == D.TypeAAAA { rr := &D.AAAA{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 1} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10} rr.AAAA = ip.AsSlice() msg.Answer = []D.RR{rr} @@ -59,6 +56,10 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[strin return next(ctx, r) } + if mapping != nil { + mapping.SetWithExpire(ip.Unmap().String(), host, time.Now().Add(time.Second*10)) + } + ctx.SetType(context.DNSTypeHost) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true diff --git a/listener/http/proxy.go b/listener/http/proxy.go index d29f80f5..bd39b8b4 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -70,11 +70,11 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, RemoveExtraHTTPHostPort(request) if request.URL.Scheme == "" || request.URL.Host == "" { - resp = ResponseWith(request, http.StatusBadRequest) + resp = responseWith(request, http.StatusBadRequest) } else { resp, err = client.Do(request) if err != nil { - resp = ResponseWith(request, http.StatusBadGateway) + resp = responseWith(request, http.StatusBadGateway) } } @@ -103,7 +103,7 @@ func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http if authenticator != nil { credential := parseBasicProxyAuthorization(request) if credential == "" { - resp := ResponseWith(request, http.StatusProxyAuthRequired) + resp := responseWith(request, http.StatusProxyAuthRequired) resp.Header.Set("Proxy-Authenticate", "Basic") return resp } @@ -117,14 +117,14 @@ func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http if !authed { log.Infoln("Auth failed from %s", request.RemoteAddr) - return ResponseWith(request, http.StatusForbidden) + return responseWith(request, http.StatusForbidden) } } return nil } -func ResponseWith(request *http.Request, statusCode int) *http.Response { +func responseWith(request *http.Request, statusCode int) *http.Response { return &http.Response{ StatusCode: statusCode, Status: http.StatusText(statusCode), diff --git a/listener/http/utils.go b/listener/http/utils.go index 0e7c7535..94308f19 100644 --- a/listener/http/utils.go +++ b/listener/http/utils.go @@ -40,7 +40,7 @@ func RemoveExtraHTTPHostPort(req *http.Request) { host = req.URL.Host } - if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" { + if pHost, port, err := net.SplitHostPort(host); err == nil && (port == "80" || port == "443") { host = pHost } diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index a0d3bab6..88fef3db 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -17,7 +17,7 @@ import ( "github.com/Dreamacro/clash/common/cache" N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" - httpL "github.com/Dreamacro/clash/listener/http" + H "github.com/Dreamacro/clash/listener/http" ) func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { @@ -48,9 +48,12 @@ startOver: readLoop: for { - _ = conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive + err := conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive + if err != nil { + break readLoop + } - request, err := httpL.ReadRequest(conn.Reader()) + request, err := H.ReadRequest(conn.Reader()) if err != nil { handleError(opt, nil, err) break readLoop @@ -58,15 +61,15 @@ readLoop: var response *http.Response - session := NewSession(conn, request, response) + session := newSession(conn, request, response) source = parseSourceAddress(session.request, c, source) - request.RemoteAddr = source.String() + session.request.RemoteAddr = source.String() if !trusted { - response = httpL.Authenticate(request, cache) + session.response = H.Authenticate(session.request, cache) - trusted = response == nil + trusted = session.response == nil } if trusted { @@ -84,19 +87,18 @@ readLoop: break readLoop // close connection } - buf := make([]byte, session.conn.(*N.BufferedConn).Buffered()) - _, _ = session.conn.Read(buf) + buff := make([]byte, session.conn.(*N.BufferedConn).Buffered()) + _, _ = session.conn.Read(buff) - mc := &MultiReaderConn{ + mrc := &multiReaderConn{ Conn: session.conn, - reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), session.conn), + reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn), } - // 22 is the TLS handshake. - // https://tools.ietf.org/html/rfc5246#section-6.2.1 - if b[0] == 22 { + // TLS handshake. + if b[0] == 0x16 { // TODO serve by generic host name maybe better? - tlsConn := tls.Server(mc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) + tlsConn := tls.Server(mrc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) // Handshake with the local client if err = tlsConn.Handshake(); err != nil { @@ -109,15 +111,17 @@ readLoop: } // maybe it's the others encrypted connection - in <- inbound.NewHTTPS(request, mc) + in <- inbound.NewHTTPS(session.request, mrc) } // maybe it's a http connection goto readLoop } + prepareRequest(c, session.request) + // hijack api - if getHostnameWithoutPort(session.request) == opt.ApiHost { + if session.request.URL.Host == opt.ApiHost { if err = handleApiRequest(session, opt); err != nil { handleError(opt, session, err) break readLoop @@ -125,8 +129,6 @@ readLoop: return } - prepareRequest(c, session.request) - // hijack custom request and write back custom response if necessary if opt.Handler != nil { newReq, newRes := opt.Handler.HandleRequest(session) @@ -144,12 +146,9 @@ readLoop: } } - httpL.RemoveHopByHopHeaders(session.request.Header) - httpL.RemoveExtraHTTPHostPort(request) - session.request.RequestURI = "" - if session.request.URL.Scheme == "" || session.request.URL.Host == "" { + if session.request.URL.Host == "" { session.response = session.NewErrorResponse(errors.New("invalid URL")) } else { client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) @@ -162,6 +161,8 @@ readLoop: session.response = session.NewErrorResponse(err) if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { // TODO block unsupported host? + _ = writeResponse(session, false) + break readLoop } } } @@ -194,7 +195,7 @@ func writeResponseWithHandler(session *Session, opt *Option) error { } func writeResponse(session *Session, keepAlive bool) error { - httpL.RemoveHopByHopHeaders(session.response.Header) + H.RemoveHopByHopHeaders(session.response.Header) if keepAlive { session.response.Header.Set("Connection", "keep-alive") @@ -226,17 +227,15 @@ func handleApiRequest(session *Session, opt *Option) error { return session.response.Write(session.conn) } - b := ` - - Clash ManInTheMiddle Proxy Services - 404 Not Found - - -

Not Found

-

The requested URL %s was not found on this server.

- - + b := ` + +Clash MITM Proxy Services - 404 Not Found + +

Not Found

+

The requested URL %s was not found on this server.

+ ` + if opt.Handler != nil { if opt.Handler.HandleApiRequest(session) { return nil @@ -261,10 +260,7 @@ func handleApiRequest(session *Session, opt *Option) error { func handleError(opt *Option, session *Session, err error) { if opt.Handler != nil { opt.Handler.HandleError(session, err) - return } - - // log.Errorln("[MITM] process mitm error: %v", err) } func prepareRequest(conn net.Conn, request *http.Request) { @@ -277,7 +273,9 @@ func prepareRequest(conn net.Conn, request *http.Request) { request.URL.Host = request.Host } - request.URL.Scheme = "http" + if request.URL.Scheme == "" { + request.URL.Scheme = "http" + } if tlsConn, ok := conn.(*tls.Conn); ok { cs := tlsConn.ConnectionState() @@ -289,6 +287,9 @@ func prepareRequest(conn net.Conn, request *http.Request) { if request.Header.Get("Accept-Encoding") != "" { request.Header.Set("Accept-Encoding", "gzip") } + + H.RemoveHopByHopHeaders(request.Header) + H.RemoveExtraHTTPHostPort(request) } func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool { @@ -303,19 +304,6 @@ func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool { return false } -func getHostnameWithoutPort(req *http.Request) string { - host := req.Host - if host == "" { - host = req.URL.Host - } - - if pHost, _, err := net.SplitHostPort(host); err == nil { - host = pHost - } - - return host -} - func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr { if source != nil { return source diff --git a/listener/mitm/session.go b/listener/mitm/session.go index 2572d879..42c7faf7 100644 --- a/listener/mitm/session.go +++ b/listener/mitm/session.go @@ -1,16 +1,11 @@ package mitm import ( - "fmt" "io" "net" "net/http" - - C "github.com/Dreamacro/clash/constant" ) -var serverName = fmt.Sprintf("Clash server (%s)", C.Version) - type Session struct { conn net.Conn request *http.Request @@ -37,16 +32,14 @@ func (s *Session) SetProperties(key string, val any) { } func (s *Session) NewResponse(code int, body io.Reader) *http.Response { - res := NewResponse(code, body, s.request) - res.Header.Set("Server", serverName) - return res + return NewResponse(code, body, s.request) } func (s *Session) NewErrorResponse(err error) *http.Response { return NewErrorResponse(s.request, err) } -func NewSession(conn net.Conn, request *http.Request, response *http.Response) *Session { +func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session { return &Session{ conn: conn, request: request, diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index 7d681d42..8ca8054d 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -14,12 +14,12 @@ import ( "golang.org/x/text/transform" ) -type MultiReaderConn struct { +type multiReaderConn struct { net.Conn reader io.Reader } -func (c *MultiReaderConn) Read(buf []byte) (int, error) { +func (c *multiReaderConn) Read(buf []byte) (int, error) { return c.reader.Read(buf) } @@ -65,7 +65,6 @@ func NewErrorResponse(req *http.Request, err error) *http.Response { w := fmt.Sprintf(`199 "clash" %q %q`, err.Error(), date) res.Header.Add("Warning", w) - res.Header.Set("Server", serverName) return res } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index d2cb95be..e16782a7 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -181,7 +181,7 @@ func preHandleMetadata(metadata *C.Metadata) error { return nil } -func resolveMetadata(ctx C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) { +func resolveMetadata(_ C.PlainContext, metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err error) { switch mode { case Direct: proxy = proxies["DIRECT"] @@ -217,7 +217,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { handle := func() bool { pc := natTable.Get(key) if pc != nil { - handleUDPToRemote(packet, pc, metadata) + _ = handleUDPToRemote(packet, pc, metadata) return true } return false @@ -284,7 +284,9 @@ func handleUDPConn(packet *inbound.PacketAdapter) { } func handleTCPConn(connCtx C.ConnContext) { - defer connCtx.Conn().Close() + defer func(conn net.Conn) { + _ = conn.Close() + }(connCtx.Conn()) metadata := connCtx.Metadata() if !metadata.Valid() { @@ -302,7 +304,9 @@ func handleTCPConn(connCtx C.ConnContext) { if MitmOutbound != nil && metadata.Type != C.MITM { if remoteConn, err1 := MitmOutbound.DialContext(ctx, metadata); err1 == nil { remoteConn = statistic.NewSniffing(remoteConn, metadata) - defer remoteConn.Close() + defer func(remoteConn C.Conn) { + _ = remoteConn.Close() + }(remoteConn) handleSocket(connCtx, remoteConn) return @@ -325,7 +329,9 @@ func handleTCPConn(connCtx C.ConnContext) { return } remoteConn = statistic.NewTCPTracker(remoteConn, statistic.DefaultManager, metadata, rule) - defer remoteConn.Close() + defer func(remoteConn C.Conn) { + _ = remoteConn.Close() + }(remoteConn) switch true { case rule != nil: From 008ee613ab23d8cedef18e689dc5f9c627c1bf87 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 12 Apr 2022 00:31:04 +0800 Subject: [PATCH 2/2] Refactor: fakeip pool use netip.Prefix, supports ipv6 range --- component/fakeip/cachefile.go | 33 ++++--- component/fakeip/memory.go | 37 ++++---- component/fakeip/pool.go | 167 ++++++++++++++++++---------------- component/fakeip/pool_test.go | 118 +++++++++++++++--------- config/config.go | 4 +- dns/enhancer.go | 22 +++-- dns/middleware.go | 10 +- dns/util.go | 17 ++++ hub/executor/executor.go | 4 +- listener/listener.go | 3 +- listener/tun/tun_adapter.go | 18 ++-- 11 files changed, 254 insertions(+), 179 deletions(-) diff --git a/component/fakeip/cachefile.go b/component/fakeip/cachefile.go index 9e5f22f4..c31d751f 100644 --- a/component/fakeip/cachefile.go +++ b/component/fakeip/cachefile.go @@ -1,7 +1,7 @@ package fakeip import ( - "net" + "net/netip" "github.com/Dreamacro/clash/component/profile/cachefile" ) @@ -11,22 +11,27 @@ type cachefileStore struct { } // GetByHost implements store.GetByHost -func (c *cachefileStore) GetByHost(host string) (net.IP, bool) { +func (c *cachefileStore) GetByHost(host string) (netip.Addr, bool) { elm := c.cache.GetFakeip([]byte(host)) if elm == nil { - return nil, false + return netip.Addr{}, false + } + + if len(elm) == 4 { + return netip.AddrFrom4(*(*[4]byte)(elm)), true + } else { + return netip.AddrFrom16(*(*[16]byte)(elm)), true } - return net.IP(elm), true } // PutByHost implements store.PutByHost -func (c *cachefileStore) PutByHost(host string, ip net.IP) { - c.cache.PutFakeip([]byte(host), ip) +func (c *cachefileStore) PutByHost(host string, ip netip.Addr) { + c.cache.PutFakeip([]byte(host), ip.AsSlice()) } // GetByIP implements store.GetByIP -func (c *cachefileStore) GetByIP(ip net.IP) (string, bool) { - elm := c.cache.GetFakeip(ip.To4()) +func (c *cachefileStore) GetByIP(ip netip.Addr) (string, bool) { + elm := c.cache.GetFakeip(ip.AsSlice()) if elm == nil { return "", false } @@ -34,18 +39,18 @@ func (c *cachefileStore) GetByIP(ip net.IP) (string, bool) { } // PutByIP implements store.PutByIP -func (c *cachefileStore) PutByIP(ip net.IP, host string) { - c.cache.PutFakeip(ip.To4(), []byte(host)) +func (c *cachefileStore) PutByIP(ip netip.Addr, host string) { + c.cache.PutFakeip(ip.AsSlice(), []byte(host)) } // DelByIP implements store.DelByIP -func (c *cachefileStore) DelByIP(ip net.IP) { - ip = ip.To4() - c.cache.DelFakeipPair(ip, c.cache.GetFakeip(ip.To4())) +func (c *cachefileStore) DelByIP(ip netip.Addr) { + addr := ip.AsSlice() + c.cache.DelFakeipPair(addr, c.cache.GetFakeip(addr)) } // Exist implements store.Exist -func (c *cachefileStore) Exist(ip net.IP) bool { +func (c *cachefileStore) Exist(ip netip.Addr) bool { _, exist := c.GetByIP(ip) return exist } diff --git a/component/fakeip/memory.go b/component/fakeip/memory.go index 2568b1d9..5566ce48 100644 --- a/component/fakeip/memory.go +++ b/component/fakeip/memory.go @@ -1,35 +1,35 @@ package fakeip import ( - "net" + "net/netip" "github.com/Dreamacro/clash/common/cache" ) type memoryStore struct { - cacheIP *cache.LruCache[string, net.IP] - cacheHost *cache.LruCache[uint32, string] + cacheIP *cache.LruCache[string, netip.Addr] + cacheHost *cache.LruCache[netip.Addr, string] } // GetByHost implements store.GetByHost -func (m *memoryStore) GetByHost(host string) (net.IP, bool) { +func (m *memoryStore) GetByHost(host string) (netip.Addr, bool) { if ip, exist := m.cacheIP.Get(host); exist { // ensure ip --> host on head of linked list - m.cacheHost.Get(ipToUint(ip.To4())) + m.cacheHost.Get(ip) return ip, true } - return nil, false + return netip.Addr{}, false } // PutByHost implements store.PutByHost -func (m *memoryStore) PutByHost(host string, ip net.IP) { +func (m *memoryStore) PutByHost(host string, ip netip.Addr) { m.cacheIP.Set(host, ip) } // GetByIP implements store.GetByIP -func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { - if host, exist := m.cacheHost.Get(ipToUint(ip.To4())); exist { +func (m *memoryStore) GetByIP(ip netip.Addr) (string, bool) { + if host, exist := m.cacheHost.Get(ip); exist { // ensure host --> ip on head of linked list m.cacheIP.Get(host) return host, true @@ -39,22 +39,21 @@ func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { } // PutByIP implements store.PutByIP -func (m *memoryStore) PutByIP(ip net.IP, host string) { - m.cacheHost.Set(ipToUint(ip.To4()), host) +func (m *memoryStore) PutByIP(ip netip.Addr, host string) { + m.cacheHost.Set(ip, host) } // DelByIP implements store.DelByIP -func (m *memoryStore) DelByIP(ip net.IP) { - ipNum := ipToUint(ip.To4()) - if host, exist := m.cacheHost.Get(ipNum); exist { +func (m *memoryStore) DelByIP(ip netip.Addr) { + if host, exist := m.cacheHost.Get(ip); exist { m.cacheIP.Delete(host) } - m.cacheHost.Delete(ipNum) + m.cacheHost.Delete(ip) } // Exist implements store.Exist -func (m *memoryStore) Exist(ip net.IP) bool { - return m.cacheHost.Exist(ipToUint(ip.To4())) +func (m *memoryStore) Exist(ip netip.Addr) bool { + return m.cacheHost.Exist(ip) } // CloneTo implements store.CloneTo @@ -74,7 +73,7 @@ func (m *memoryStore) FlushFakeIP() error { func newMemoryStore(size int) *memoryStore { return &memoryStore{ - cacheIP: cache.NewLRUCache[string, net.IP](cache.WithSize[string, net.IP](size)), - cacheHost: cache.NewLRUCache[uint32, string](cache.WithSize[uint32, string](size)), + cacheIP: cache.NewLRUCache[string, netip.Addr](cache.WithSize[string, netip.Addr](size)), + cacheHost: cache.NewLRUCache[netip.Addr, string](cache.WithSize[netip.Addr, string](size)), } } diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index afc1691b..2b887fc3 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -2,39 +2,52 @@ package fakeip import ( "errors" - "net" + "math/bits" + "net/netip" "sync" + _ "unsafe" "github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/trie" ) +//go:linkname beUint64 net/netip.beUint64 +func beUint64(b []byte) uint64 + +//go:linkname bePutUint64 net/netip.bePutUint64 +func bePutUint64(b []byte, v uint64) + +type uint128 struct { + hi uint64 + lo uint64 +} + type store interface { - GetByHost(host string) (net.IP, bool) - PutByHost(host string, ip net.IP) - GetByIP(ip net.IP) (string, bool) - PutByIP(ip net.IP, host string) - DelByIP(ip net.IP) - Exist(ip net.IP) bool + GetByHost(host string) (netip.Addr, bool) + PutByHost(host string, ip netip.Addr) + GetByIP(ip netip.Addr) (string, bool) + PutByIP(ip netip.Addr, host string) + DelByIP(ip netip.Addr) + Exist(ip netip.Addr) bool CloneTo(store) FlushFakeIP() error } // Pool is a implementation about fake ip generator without storage type Pool struct { - max uint32 - min uint32 - gateway uint32 - broadcast uint32 - offset uint32 - mux sync.Mutex - host *trie.DomainTrie[bool] - ipnet *net.IPNet - store store + gateway netip.Addr + first netip.Addr + last netip.Addr + offset netip.Addr + cycle bool + mux sync.Mutex + host *trie.DomainTrie[bool] + ipnet *netip.Prefix + store store } // Lookup return a fake ip with host -func (p *Pool) Lookup(host string) net.IP { +func (p *Pool) Lookup(host string) netip.Addr { p.mux.Lock() defer p.mux.Unlock() if ip, exist := p.store.GetByHost(host); exist { @@ -47,14 +60,10 @@ func (p *Pool) Lookup(host string) net.IP { } // LookBack return host with the fake ip -func (p *Pool) LookBack(ip net.IP) (string, bool) { +func (p *Pool) LookBack(ip netip.Addr) (string, bool) { p.mux.Lock() defer p.mux.Unlock() - if ip = ip.To4(); ip == nil { - return "", false - } - return p.store.GetByIP(ip) } @@ -67,29 +76,25 @@ func (p *Pool) ShouldSkipped(domain string) bool { } // Exist returns if given ip exists in fake-ip pool -func (p *Pool) Exist(ip net.IP) bool { +func (p *Pool) Exist(ip netip.Addr) bool { p.mux.Lock() defer p.mux.Unlock() - if ip = ip.To4(); ip == nil { - return false - } - return p.store.Exist(ip) } // Gateway return gateway ip -func (p *Pool) Gateway() net.IP { - return uintToIP(p.gateway) +func (p *Pool) Gateway() netip.Addr { + return p.gateway } -// Broadcast return broadcast ip -func (p *Pool) Broadcast() net.IP { - return uintToIP(p.broadcast) +// Broadcast return the last ip +func (p *Pool) Broadcast() netip.Addr { + return p.last } // IPNet return raw ipnet -func (p *Pool) IPNet() *net.IPNet { +func (p *Pool) IPNet() *netip.Prefix { return p.ipnet } @@ -98,46 +103,28 @@ func (p *Pool) CloneFrom(o *Pool) { o.store.CloneTo(p.store) } -func (p *Pool) get(host string) net.IP { - current := p.offset - for { - p.offset = (p.offset + 1) % (p.max - p.min) - // Avoid infinite loops - if p.offset == current { - p.offset = (p.offset + 1) % (p.max - p.min) - ip := uintToIP(p.min + p.offset - 1) - p.store.DelByIP(ip) - break - } +func (p *Pool) get(host string) netip.Addr { + p.offset = p.offset.Next() - ip := uintToIP(p.min + p.offset - 1) - if !p.store.Exist(ip) { - break - } + if !p.offset.Less(p.last) { + p.cycle = true + p.offset = p.first } - ip := uintToIP(p.min + p.offset - 1) - p.store.PutByIP(ip, host) - return ip + + if p.cycle { + p.store.DelByIP(p.offset) + } + + p.store.PutByIP(p.offset, host) + return p.offset } func (p *Pool) FlushFakeIP() error { return p.store.FlushFakeIP() } -func ipToUint(ip net.IP) uint32 { - v := uint32(ip[0]) << 24 - v += uint32(ip[1]) << 16 - v += uint32(ip[2]) << 8 - v += uint32(ip[3]) - return v -} - -func uintToIP(v uint32) net.IP { - return net.IP{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)} -} - type Options struct { - IPNet *net.IPNet + IPNet *netip.Prefix Host *trie.DomainTrie[bool] // Size sets the maximum number of entries in memory @@ -151,23 +138,25 @@ type Options struct { // New return Pool instance func New(options Options) (*Pool, error) { - min := ipToUint(options.IPNet.IP) + 3 + var ( + hostAddr = options.IPNet.Masked().Addr() + gateway = hostAddr.Next() + first = gateway.Next().Next() + last = add(hostAddr, 1<