diff --git a/adapter/outbound/vless.go b/adapter/outbound/vless.go index c60e02b6..84d12666 100644 --- a/adapter/outbound/vless.go +++ b/adapter/outbound/vless.go @@ -152,7 +152,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vless) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveFirstIP(metadata.Host) if err != nil { return nil, errors.New("can't resolve ip") } @@ -245,7 +245,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o if v.transport != nil && len(opts) == 0 { // vless use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveFirstIP(metadata.Host) if err != nil { return nil, errors.New("can't resolve ip") } diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index 47035a82..d77de253 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -203,7 +203,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vmess) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveFirstIP(metadata.Host) if err != nil { return c, fmt.Errorf("can't resolve ip: %w", err) } @@ -255,7 +255,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o if v.transport != nil && len(opts) == 0 { // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveFirstIP(metadata.Host) if err != nil { return nil, fmt.Errorf("can't resolve ip: %w", err) } diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index 7fe625c1..cf8b71c3 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -37,17 +37,17 @@ var ( ) type Resolver interface { - ResolveIP(host string) (ip netip.Addr, err error) - ResolveIPv4(host string) (ip netip.Addr, err error) - ResolveIPv6(host string) (ip netip.Addr, err error) + ResolveIP(host string, random bool) (ip netip.Addr, err error) + ResolveIPv4(host string, random bool) (ip netip.Addr, err error) + ResolveIPv6(host string, random bool) (ip netip.Addr, err error) } // ResolveIPv4 with a host, return ipv4 func ResolveIPv4(host string) (netip.Addr, error) { - return ResolveIPv4WithResolver(host, DefaultResolver) + return ResolveIPv4WithResolver(host, DefaultResolver, true) } -func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { +func ResolveIPv4WithResolver(host string, r Resolver, random bool) (netip.Addr, error) { if node := DefaultHosts.Search(host); node != nil { if ip := node.Data; ip.Is4() { return ip, nil @@ -56,6 +56,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { ip, err := netip.ParseAddr(host) if err == nil { + ip = ip.Unmap() if ip.Is4() { return ip, nil } @@ -63,7 +64,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { } if r != nil { - return r.ResolveIPv4(host) + return r.ResolveIPv4(host, random) } if DefaultResolver == nil { @@ -76,7 +77,11 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { return netip.Addr{}, ErrIPNotFound } - ip := ipAddrs[rand.Intn(len(ipAddrs))].To4() + index := 0 + if random { + index = rand.Intn(len(ipAddrs)) + } + ip := ipAddrs[index].To4() if ip == nil { return netip.Addr{}, ErrIPVersion } @@ -89,10 +94,10 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { // ResolveIPv6 with a host, return ipv6 func ResolveIPv6(host string) (netip.Addr, error) { - return ResolveIPv6WithResolver(host, DefaultResolver) + return ResolveIPv6WithResolver(host, DefaultResolver, true) } -func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) { +func ResolveIPv6WithResolver(host string, r Resolver, random bool) (netip.Addr, error) { if DisableIPv6 { return netip.Addr{}, ErrIPv6Disabled } @@ -112,7 +117,7 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) { } if r != nil { - return r.ResolveIPv6(host) + return r.ResolveIPv6(host, random) } if DefaultResolver == nil { @@ -125,25 +130,29 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) { return netip.Addr{}, ErrIPNotFound } - return netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))])), nil + index := 0 + if random { + index = rand.Intn(len(ipAddrs)) + } + return netip.AddrFrom16(*(*[16]byte)(ipAddrs[index])), nil } return netip.Addr{}, ErrIPNotFound } // ResolveIPWithResolver same as ResolveIP, but with a resolver -func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) { +func ResolveIPWithResolver(host string, r Resolver, random bool) (netip.Addr, error) { if node := DefaultHosts.Search(host); node != nil { return node.Data, nil } if r != nil { if DisableIPv6 { - return r.ResolveIPv4(host) + return r.ResolveIPv4(host, random) } - return r.ResolveIP(host) + return r.ResolveIP(host, random) } else if DisableIPv6 { - return ResolveIPv4(host) + return resolveIP(host, random) } ip, err := netip.ParseAddr(host) @@ -165,13 +174,18 @@ func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) { // ResolveIP with a host, return ip func ResolveIP(host string) (netip.Addr, error) { - return ResolveIPWithResolver(host, DefaultResolver) + return resolveIP(host, true) +} + +// ResolveFirstIP with a host, return ip +func ResolveFirstIP(host string) (netip.Addr, error) { + return resolveIP(host, false) } // ResolveIPv4ProxyServerHost proxies server host only func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) { if ProxyServerHostResolver != nil { - return ResolveIPv4WithResolver(host, ProxyServerHostResolver) + return ResolveIPv4WithResolver(host, ProxyServerHostResolver, true) } return ResolveIPv4(host) } @@ -179,7 +193,7 @@ func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) { // ResolveIPv6ProxyServerHost proxies server host only func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) { if ProxyServerHostResolver != nil { - return ResolveIPv6WithResolver(host, ProxyServerHostResolver) + return ResolveIPv6WithResolver(host, ProxyServerHostResolver, true) } return ResolveIPv6(host) } @@ -187,7 +201,11 @@ func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) { // ResolveProxyServerHost proxies server host only func ResolveProxyServerHost(host string) (netip.Addr, error) { if ProxyServerHostResolver != nil { - return ResolveIPWithResolver(host, ProxyServerHostResolver) + return ResolveIPWithResolver(host, ProxyServerHostResolver, true) } return ResolveIP(host) } + +func resolveIP(host string, random bool) (netip.Addr, error) { + return ResolveIPWithResolver(host, DefaultResolver, random) +} diff --git a/dns/client.go b/dns/client.go index 4795e58c..bd4e968f 100644 --- a/dns/client.go +++ b/dns/client.go @@ -36,7 +36,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) if c.r == nil { return nil, fmt.Errorf("dns %s not a valid ip", c.host) } else { - if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { + if ip, err = resolver.ResolveIPWithResolver(c.host, c.r, true); err != nil { return nil, fmt.Errorf("use default dns resolve failed: %w", err) } c.host = ip.String() diff --git a/dns/doh.go b/dns/doh.go index 28c23164..c3ff2999 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -94,7 +94,7 @@ func newDoHClient(url string, r *Resolver, proxyAdapter string) *dohClient { return nil, err } - ip, err := resolver.ResolveIPWithResolver(host, r) + ip, err := resolver.ResolveIPWithResolver(host, r, true) if err != nil { return nil, err } diff --git a/dns/resolver.go b/dns/resolver.go index 27d23d5c..5ac23265 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -21,6 +21,8 @@ import ( "golang.org/x/sync/singleflight" ) +var _ resolver.Resolver = (*Resolver)(nil) + type dnsClient interface { Exchange(m *D.Msg) (msg *D.Msg, err error) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) @@ -45,18 +47,18 @@ type Resolver struct { } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA -func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) { +func (r *Resolver) ResolveIP(host string, random bool) (ip netip.Addr, err error) { ch := make(chan netip.Addr, 1) go func() { defer close(ch) - ip, err := r.resolveIP(host, D.TypeAAAA) + ip, err := r.resolveIP(host, D.TypeAAAA, random) if err != nil { return } ch <- ip }() - ip, err = r.resolveIP(host, D.TypeA) + ip, err = r.resolveIP(host, D.TypeA, random) if err == nil { return } @@ -70,13 +72,13 @@ func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) { } // ResolveIPv4 request with TypeA -func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) { - return r.resolveIP(host, D.TypeA) +func (r *Resolver) ResolveIPv4(host string, random bool) (ip netip.Addr, err error) { + return r.resolveIP(host, D.TypeA, random) } // ResolveIPv6 request with TypeAAAA -func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) { - return r.resolveIP(host, D.TypeAAAA) +func (r *Resolver) ResolveIPv6(host string, random bool) (ip netip.Addr, err error) { + return r.resolveIP(host, D.TypeAAAA, random) } func (r *Resolver) shouldIPFallback(ip netip.Addr) bool { @@ -255,9 +257,10 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er return } -func (r *Resolver) resolveIP(host string, dnsType uint16) (ip netip.Addr, err error) { +func (r *Resolver) resolveIP(host string, dnsType uint16, random bool) (ip netip.Addr, err error) { ip, err = netip.ParseAddr(host) if err == nil { + ip = ip.Unmap() isIPv4 := ip.Is4() if dnsType == D.TypeAAAA && !isIPv4 { return ip, nil @@ -282,7 +285,12 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip netip.Addr, err er return netip.Addr{}, resolver.ErrIPNotFound } - ip = ips[rand.Intn(ipLength)] + index := 0 + if random { + index = rand.Intn(ipLength) + } + + ip = ips[index] return } diff --git a/tunnel/connection.go b/tunnel/connection.go index 7ab6c4d1..0f7c6bde 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -16,7 +16,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata // local resolve UDP dns if !metadata.Resolved() { - ip, err := resolver.ResolveIP(metadata.Host) + ip, err := resolver.ResolveFirstIP(metadata.Host) if err != nil { return err }