Refactor: udp would use the first DNS record instead of a random one

This commit is contained in:
yaling888 2022-06-22 03:17:15 +08:00
parent f1fc0ef2ff
commit e31be4edc2
7 changed files with 61 additions and 35 deletions

View File

@ -152,7 +152,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
func (v *Vless) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vless) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil { if err != nil {
return nil, errors.New("can't resolve ip") return nil, errors.New("can't resolve ip")
} }
@ -245,7 +245,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
if v.transport != nil && len(opts) == 0 { if v.transport != nil && len(opts) == 0 {
// vless use stream-oriented udp with a special address, so we needs a net.UDPAddr // vless use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil { if err != nil {
return nil, errors.New("can't resolve ip") return nil, errors.New("can't resolve ip")
} }

View File

@ -203,7 +203,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
func (v *Vmess) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (v *Vmess) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil { if err != nil {
return c, fmt.Errorf("can't resolve ip: %w", err) return c, fmt.Errorf("can't resolve ip: %w", err)
} }
@ -255,7 +255,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
if v.transport != nil && len(opts) == 0 { if v.transport != nil && len(opts) == 0 {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr // vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't resolve ip: %w", err) return nil, fmt.Errorf("can't resolve ip: %w", err)
} }

View File

@ -37,17 +37,17 @@ var (
) )
type Resolver interface { type Resolver interface {
ResolveIP(host string) (ip netip.Addr, err error) ResolveIP(host string, random bool) (ip netip.Addr, err error)
ResolveIPv4(host string) (ip netip.Addr, err error) ResolveIPv4(host string, random bool) (ip netip.Addr, err error)
ResolveIPv6(host string) (ip netip.Addr, err error) ResolveIPv6(host string, random bool) (ip netip.Addr, err error)
} }
// ResolveIPv4 with a host, return ipv4 // ResolveIPv4 with a host, return ipv4
func ResolveIPv4(host string) (netip.Addr, error) { func ResolveIPv4(host string) (netip.Addr, error) {
return ResolveIPv4WithResolver(host, DefaultResolver) return ResolveIPv4WithResolver(host, DefaultResolver, true)
} }
func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) { func ResolveIPv4WithResolver(host string, r Resolver, random bool) (netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data; ip.Is4() { if ip := node.Data; ip.Is4() {
return ip, nil return ip, nil
@ -56,6 +56,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
ip, err := netip.ParseAddr(host) ip, err := netip.ParseAddr(host)
if err == nil { if err == nil {
ip = ip.Unmap()
if ip.Is4() { if ip.Is4() {
return ip, nil return ip, nil
} }
@ -63,7 +64,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
} }
if r != nil { if r != nil {
return r.ResolveIPv4(host) return r.ResolveIPv4(host, random)
} }
if DefaultResolver == nil { if DefaultResolver == nil {
@ -76,7 +77,11 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
return netip.Addr{}, ErrIPNotFound return netip.Addr{}, ErrIPNotFound
} }
ip := ipAddrs[rand.Intn(len(ipAddrs))].To4() index := 0
if random {
index = rand.Intn(len(ipAddrs))
}
ip := ipAddrs[index].To4()
if ip == nil { if ip == nil {
return netip.Addr{}, ErrIPVersion return netip.Addr{}, ErrIPVersion
} }
@ -89,10 +94,10 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
// ResolveIPv6 with a host, return ipv6 // ResolveIPv6 with a host, return ipv6
func ResolveIPv6(host string) (netip.Addr, error) { func ResolveIPv6(host string) (netip.Addr, error) {
return ResolveIPv6WithResolver(host, DefaultResolver) return ResolveIPv6WithResolver(host, DefaultResolver, true)
} }
func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) { func ResolveIPv6WithResolver(host string, r Resolver, random bool) (netip.Addr, error) {
if DisableIPv6 { if DisableIPv6 {
return netip.Addr{}, ErrIPv6Disabled return netip.Addr{}, ErrIPv6Disabled
} }
@ -112,7 +117,7 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
} }
if r != nil { if r != nil {
return r.ResolveIPv6(host) return r.ResolveIPv6(host, random)
} }
if DefaultResolver == nil { if DefaultResolver == nil {
@ -125,25 +130,29 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
return netip.Addr{}, ErrIPNotFound return netip.Addr{}, ErrIPNotFound
} }
return netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))])), nil index := 0
if random {
index = rand.Intn(len(ipAddrs))
}
return netip.AddrFrom16(*(*[16]byte)(ipAddrs[index])), nil
} }
return netip.Addr{}, ErrIPNotFound return netip.Addr{}, ErrIPNotFound
} }
// ResolveIPWithResolver same as ResolveIP, but with a resolver // ResolveIPWithResolver same as ResolveIP, but with a resolver
func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) { func ResolveIPWithResolver(host string, r Resolver, random bool) (netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
return node.Data, nil return node.Data, nil
} }
if r != nil { if r != nil {
if DisableIPv6 { if DisableIPv6 {
return r.ResolveIPv4(host) return r.ResolveIPv4(host, random)
} }
return r.ResolveIP(host) return r.ResolveIP(host, random)
} else if DisableIPv6 { } else if DisableIPv6 {
return ResolveIPv4(host) return resolveIP(host, random)
} }
ip, err := netip.ParseAddr(host) ip, err := netip.ParseAddr(host)
@ -165,13 +174,18 @@ func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) {
// ResolveIP with a host, return ip // ResolveIP with a host, return ip
func ResolveIP(host string) (netip.Addr, error) { func ResolveIP(host string) (netip.Addr, error) {
return ResolveIPWithResolver(host, DefaultResolver) return resolveIP(host, true)
}
// ResolveFirstIP with a host, return ip
func ResolveFirstIP(host string) (netip.Addr, error) {
return resolveIP(host, false)
} }
// ResolveIPv4ProxyServerHost proxies server host only // ResolveIPv4ProxyServerHost proxies server host only
func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) { func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) {
if ProxyServerHostResolver != nil { if ProxyServerHostResolver != nil {
return ResolveIPv4WithResolver(host, ProxyServerHostResolver) return ResolveIPv4WithResolver(host, ProxyServerHostResolver, true)
} }
return ResolveIPv4(host) return ResolveIPv4(host)
} }
@ -179,7 +193,7 @@ func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) {
// ResolveIPv6ProxyServerHost proxies server host only // ResolveIPv6ProxyServerHost proxies server host only
func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) { func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) {
if ProxyServerHostResolver != nil { if ProxyServerHostResolver != nil {
return ResolveIPv6WithResolver(host, ProxyServerHostResolver) return ResolveIPv6WithResolver(host, ProxyServerHostResolver, true)
} }
return ResolveIPv6(host) return ResolveIPv6(host)
} }
@ -187,7 +201,11 @@ func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) {
// ResolveProxyServerHost proxies server host only // ResolveProxyServerHost proxies server host only
func ResolveProxyServerHost(host string) (netip.Addr, error) { func ResolveProxyServerHost(host string) (netip.Addr, error) {
if ProxyServerHostResolver != nil { if ProxyServerHostResolver != nil {
return ResolveIPWithResolver(host, ProxyServerHostResolver) return ResolveIPWithResolver(host, ProxyServerHostResolver, true)
} }
return ResolveIP(host) return ResolveIP(host)
} }
func resolveIP(host string, random bool) (netip.Addr, error) {
return ResolveIPWithResolver(host, DefaultResolver, random)
}

View File

@ -36,7 +36,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
if c.r == nil { if c.r == nil {
return nil, fmt.Errorf("dns %s not a valid ip", c.host) return nil, fmt.Errorf("dns %s not a valid ip", c.host)
} else { } else {
if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil { if ip, err = resolver.ResolveIPWithResolver(c.host, c.r, true); err != nil {
return nil, fmt.Errorf("use default dns resolve failed: %w", err) return nil, fmt.Errorf("use default dns resolve failed: %w", err)
} }
c.host = ip.String() c.host = ip.String()

View File

@ -94,7 +94,7 @@ func newDoHClient(url string, r *Resolver, proxyAdapter string) *dohClient {
return nil, err return nil, err
} }
ip, err := resolver.ResolveIPWithResolver(host, r) ip, err := resolver.ResolveIPWithResolver(host, r, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -21,6 +21,8 @@ import (
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
var _ resolver.Resolver = (*Resolver)(nil)
type dnsClient interface { type dnsClient interface {
Exchange(m *D.Msg) (msg *D.Msg, err error) Exchange(m *D.Msg) (msg *D.Msg, err error)
ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
@ -45,18 +47,18 @@ type Resolver struct {
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) { func (r *Resolver) ResolveIP(host string, random bool) (ip netip.Addr, err error) {
ch := make(chan netip.Addr, 1) ch := make(chan netip.Addr, 1)
go func() { go func() {
defer close(ch) defer close(ch)
ip, err := r.resolveIP(host, D.TypeAAAA) ip, err := r.resolveIP(host, D.TypeAAAA, random)
if err != nil { if err != nil {
return return
} }
ch <- ip ch <- ip
}() }()
ip, err = r.resolveIP(host, D.TypeA) ip, err = r.resolveIP(host, D.TypeA, random)
if err == nil { if err == nil {
return return
} }
@ -70,13 +72,13 @@ func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) {
} }
// ResolveIPv4 request with TypeA // ResolveIPv4 request with TypeA
func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) { func (r *Resolver) ResolveIPv4(host string, random bool) (ip netip.Addr, err error) {
return r.resolveIP(host, D.TypeA) return r.resolveIP(host, D.TypeA, random)
} }
// ResolveIPv6 request with TypeAAAA // ResolveIPv6 request with TypeAAAA
func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) { func (r *Resolver) ResolveIPv6(host string, random bool) (ip netip.Addr, err error) {
return r.resolveIP(host, D.TypeAAAA) return r.resolveIP(host, D.TypeAAAA, random)
} }
func (r *Resolver) shouldIPFallback(ip netip.Addr) bool { func (r *Resolver) shouldIPFallback(ip netip.Addr) bool {
@ -255,9 +257,10 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er
return return
} }
func (r *Resolver) resolveIP(host string, dnsType uint16) (ip netip.Addr, err error) { func (r *Resolver) resolveIP(host string, dnsType uint16, random bool) (ip netip.Addr, err error) {
ip, err = netip.ParseAddr(host) ip, err = netip.ParseAddr(host)
if err == nil { if err == nil {
ip = ip.Unmap()
isIPv4 := ip.Is4() isIPv4 := ip.Is4()
if dnsType == D.TypeAAAA && !isIPv4 { if dnsType == D.TypeAAAA && !isIPv4 {
return ip, nil return ip, nil
@ -282,7 +285,12 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip netip.Addr, err er
return netip.Addr{}, resolver.ErrIPNotFound return netip.Addr{}, resolver.ErrIPNotFound
} }
ip = ips[rand.Intn(ipLength)] index := 0
if random {
index = rand.Intn(ipLength)
}
ip = ips[index]
return return
} }

View File

@ -16,7 +16,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
// local resolve UDP dns // local resolve UDP dns
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host) ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil { if err != nil {
return err return err
} }