Refactor: udp would use the first DNS record instead of a random one

This commit is contained in:
yaling888 2022-06-22 03:17:15 +08:00
parent f1fc0ef2ff
commit e31be4edc2
7 changed files with 61 additions and 35 deletions

View File

@ -152,7 +152,7 @@ func (v *Vless) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
func (v *Vless) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host)
ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
@ -245,7 +245,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
if v.transport != nil && len(opts) == 0 {
// vless use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host)
ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}

View File

@ -203,7 +203,7 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
func (v *Vmess) StreamPacketConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host)
ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil {
return c, fmt.Errorf("can't resolve ip: %w", err)
}
@ -255,7 +255,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o
if v.transport != nil && len(opts) == 0 {
// vmess use stream-oriented udp with a special address, so we needs a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host)
ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil {
return nil, fmt.Errorf("can't resolve ip: %w", err)
}

View File

@ -37,17 +37,17 @@ var (
)
type Resolver interface {
ResolveIP(host string) (ip netip.Addr, err error)
ResolveIPv4(host string) (ip netip.Addr, err error)
ResolveIPv6(host string) (ip netip.Addr, err error)
ResolveIP(host string, random bool) (ip netip.Addr, err error)
ResolveIPv4(host string, random bool) (ip netip.Addr, err error)
ResolveIPv6(host string, random bool) (ip netip.Addr, err error)
}
// ResolveIPv4 with a host, return ipv4
func ResolveIPv4(host string) (netip.Addr, error) {
return ResolveIPv4WithResolver(host, DefaultResolver)
return ResolveIPv4WithResolver(host, DefaultResolver, true)
}
func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
func ResolveIPv4WithResolver(host string, r Resolver, random bool) (netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data; ip.Is4() {
return ip, nil
@ -56,6 +56,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
ip, err := netip.ParseAddr(host)
if err == nil {
ip = ip.Unmap()
if ip.Is4() {
return ip, nil
}
@ -63,7 +64,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
}
if r != nil {
return r.ResolveIPv4(host)
return r.ResolveIPv4(host, random)
}
if DefaultResolver == nil {
@ -76,7 +77,11 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
return netip.Addr{}, ErrIPNotFound
}
ip := ipAddrs[rand.Intn(len(ipAddrs))].To4()
index := 0
if random {
index = rand.Intn(len(ipAddrs))
}
ip := ipAddrs[index].To4()
if ip == nil {
return netip.Addr{}, ErrIPVersion
}
@ -89,10 +94,10 @@ func ResolveIPv4WithResolver(host string, r Resolver) (netip.Addr, error) {
// ResolveIPv6 with a host, return ipv6
func ResolveIPv6(host string) (netip.Addr, error) {
return ResolveIPv6WithResolver(host, DefaultResolver)
return ResolveIPv6WithResolver(host, DefaultResolver, true)
}
func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
func ResolveIPv6WithResolver(host string, r Resolver, random bool) (netip.Addr, error) {
if DisableIPv6 {
return netip.Addr{}, ErrIPv6Disabled
}
@ -112,7 +117,7 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
}
if r != nil {
return r.ResolveIPv6(host)
return r.ResolveIPv6(host, random)
}
if DefaultResolver == nil {
@ -125,25 +130,29 @@ func ResolveIPv6WithResolver(host string, r Resolver) (netip.Addr, error) {
return netip.Addr{}, ErrIPNotFound
}
return netip.AddrFrom16(*(*[16]byte)(ipAddrs[rand.Intn(len(ipAddrs))])), nil
index := 0
if random {
index = rand.Intn(len(ipAddrs))
}
return netip.AddrFrom16(*(*[16]byte)(ipAddrs[index])), nil
}
return netip.Addr{}, ErrIPNotFound
}
// ResolveIPWithResolver same as ResolveIP, but with a resolver
func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) {
func ResolveIPWithResolver(host string, r Resolver, random bool) (netip.Addr, error) {
if node := DefaultHosts.Search(host); node != nil {
return node.Data, nil
}
if r != nil {
if DisableIPv6 {
return r.ResolveIPv4(host)
return r.ResolveIPv4(host, random)
}
return r.ResolveIP(host)
return r.ResolveIP(host, random)
} else if DisableIPv6 {
return ResolveIPv4(host)
return resolveIP(host, random)
}
ip, err := netip.ParseAddr(host)
@ -165,13 +174,18 @@ func ResolveIPWithResolver(host string, r Resolver) (netip.Addr, error) {
// ResolveIP with a host, return ip
func ResolveIP(host string) (netip.Addr, error) {
return ResolveIPWithResolver(host, DefaultResolver)
return resolveIP(host, true)
}
// ResolveFirstIP with a host, return ip
func ResolveFirstIP(host string) (netip.Addr, error) {
return resolveIP(host, false)
}
// ResolveIPv4ProxyServerHost proxies server host only
func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveIPv4WithResolver(host, ProxyServerHostResolver)
return ResolveIPv4WithResolver(host, ProxyServerHostResolver, true)
}
return ResolveIPv4(host)
}
@ -179,7 +193,7 @@ func ResolveIPv4ProxyServerHost(host string) (netip.Addr, error) {
// ResolveIPv6ProxyServerHost proxies server host only
func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveIPv6WithResolver(host, ProxyServerHostResolver)
return ResolveIPv6WithResolver(host, ProxyServerHostResolver, true)
}
return ResolveIPv6(host)
}
@ -187,7 +201,11 @@ func ResolveIPv6ProxyServerHost(host string) (netip.Addr, error) {
// ResolveProxyServerHost proxies server host only
func ResolveProxyServerHost(host string) (netip.Addr, error) {
if ProxyServerHostResolver != nil {
return ResolveIPWithResolver(host, ProxyServerHostResolver)
return ResolveIPWithResolver(host, ProxyServerHostResolver, true)
}
return ResolveIP(host)
}
func resolveIP(host string, random bool) (netip.Addr, error) {
return ResolveIPWithResolver(host, DefaultResolver, random)
}

View File

@ -36,7 +36,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
if c.r == nil {
return nil, fmt.Errorf("dns %s not a valid ip", c.host)
} else {
if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil {
if ip, err = resolver.ResolveIPWithResolver(c.host, c.r, true); err != nil {
return nil, fmt.Errorf("use default dns resolve failed: %w", err)
}
c.host = ip.String()

View File

@ -94,7 +94,7 @@ func newDoHClient(url string, r *Resolver, proxyAdapter string) *dohClient {
return nil, err
}
ip, err := resolver.ResolveIPWithResolver(host, r)
ip, err := resolver.ResolveIPWithResolver(host, r, true)
if err != nil {
return nil, err
}

View File

@ -21,6 +21,8 @@ import (
"golang.org/x/sync/singleflight"
)
var _ resolver.Resolver = (*Resolver)(nil)
type dnsClient interface {
Exchange(m *D.Msg) (msg *D.Msg, err error)
ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
@ -45,18 +47,18 @@ type Resolver struct {
}
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA
func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) {
func (r *Resolver) ResolveIP(host string, random bool) (ip netip.Addr, err error) {
ch := make(chan netip.Addr, 1)
go func() {
defer close(ch)
ip, err := r.resolveIP(host, D.TypeAAAA)
ip, err := r.resolveIP(host, D.TypeAAAA, random)
if err != nil {
return
}
ch <- ip
}()
ip, err = r.resolveIP(host, D.TypeA)
ip, err = r.resolveIP(host, D.TypeA, random)
if err == nil {
return
}
@ -70,13 +72,13 @@ func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) {
}
// ResolveIPv4 request with TypeA
func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) {
return r.resolveIP(host, D.TypeA)
func (r *Resolver) ResolveIPv4(host string, random bool) (ip netip.Addr, err error) {
return r.resolveIP(host, D.TypeA, random)
}
// ResolveIPv6 request with TypeAAAA
func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) {
return r.resolveIP(host, D.TypeAAAA)
func (r *Resolver) ResolveIPv6(host string, random bool) (ip netip.Addr, err error) {
return r.resolveIP(host, D.TypeAAAA, random)
}
func (r *Resolver) shouldIPFallback(ip netip.Addr) bool {
@ -255,9 +257,10 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er
return
}
func (r *Resolver) resolveIP(host string, dnsType uint16) (ip netip.Addr, err error) {
func (r *Resolver) resolveIP(host string, dnsType uint16, random bool) (ip netip.Addr, err error) {
ip, err = netip.ParseAddr(host)
if err == nil {
ip = ip.Unmap()
isIPv4 := ip.Is4()
if dnsType == D.TypeAAAA && !isIPv4 {
return ip, nil
@ -282,7 +285,12 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip netip.Addr, err er
return netip.Addr{}, resolver.ErrIPNotFound
}
ip = ips[rand.Intn(ipLength)]
index := 0
if random {
index = rand.Intn(ipLength)
}
ip = ips[index]
return
}

View File

@ -16,7 +16,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(metadata.Host)
ip, err := resolver.ResolveFirstIP(metadata.Host)
if err != nil {
return err
}