chore: always pass context when resolve dns

This commit is contained in:
wwqgtxx
2022-11-12 13:18:36 +08:00
parent dbadf37823
commit 901a47318d
20 changed files with 156 additions and 135 deletions

View File

@ -38,7 +38,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
if c.r == nil {
return nil, fmt.Errorf("dns %s not a valid ip", c.host)
} else {
if ip, err = resolver.ResolveIPWithResolver(c.host, c.r); err != nil {
if ip, err = resolver.ResolveIPWithResolver(ctx, c.host, c.r); err != nil {
return nil, fmt.Errorf("use default dns resolve failed: %w", err)
}
c.host = ip.String()

View File

@ -347,7 +347,7 @@ func (doq *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
if err != nil {
return nil, err
}
conn, err := dialContextExtra(ctx, doq.proxyAdapter, "udp", ipAddr, port)
if err != nil {
return nil, err
@ -505,7 +505,7 @@ func getDialHandler(r *Resolver, proxyAdapter string) dialHandler {
if err != nil {
return nil, err
}
ip, err := r.ResolveIP(host)
ip, err := r.ResolveIP(ctx, host)
if err != nil {
return nil, err
}

View File

@ -156,7 +156,7 @@ func withResolver(resolver *Resolver) handler {
return handleMsgWithEmptyAnswer(r), nil
}
msg, err := resolver.Exchange(r)
msg, err := resolver.ExchangeContext(ctx, r)
if err != nil {
log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
return msg, err

View File

@ -1,14 +1,18 @@
package dns
import D "github.com/miekg/dns"
import (
"context"
D "github.com/miekg/dns"
)
type LocalServer struct {
handler handler
}
// ServeMsg implement resolver.LocalServer ResolveMsg
func (s *LocalServer) ServeMsg(msg *D.Msg) (*D.Msg, error) {
return handlerWithContext(s.handler, msg)
func (s *LocalServer) ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) {
return handlerWithContext(ctx, s.handler, msg)
}
func NewLocalServer(resolver *Resolver, mapper *ResolverEnhancer) *LocalServer {

View File

@ -44,18 +44,18 @@ type Resolver struct {
proxyServer []dnsClient
}
func (r *Resolver) ResolveAllIPPrimaryIPv4(host string) (ips []netip.Addr, err error) {
func (r *Resolver) LookupIPPrimaryIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) {
ch := make(chan []netip.Addr, 1)
go func() {
defer close(ch)
ip, err := r.resolveIP(host, D.TypeAAAA)
ip, err := r.resolveIP(ctx, host, D.TypeAAAA)
if err != nil {
return
}
ch <- ip
}()
ips, err = r.resolveIP(host, D.TypeA)
ips, err = r.resolveIP(ctx, host, D.TypeA)
if err == nil {
return
}
@ -68,11 +68,11 @@ func (r *Resolver) ResolveAllIPPrimaryIPv4(host string) (ips []netip.Addr, err e
return ip, nil
}
func (r *Resolver) ResolveAllIP(host string) (ips []netip.Addr, err error) {
func (r *Resolver) LookupIP(ctx context.Context, host string) (ips []netip.Addr, err error) {
ch := make(chan []netip.Addr, 1)
go func() {
defer close(ch)
ip, err := r.resolveIP(host, D.TypeAAAA)
ip, err := r.resolveIP(ctx, host, D.TypeAAAA)
if err != nil {
return
}
@ -80,7 +80,7 @@ func (r *Resolver) ResolveAllIP(host string) (ips []netip.Addr, err error) {
ch <- ip
}()
ips, err = r.resolveIP(host, D.TypeA)
ips, err = r.resolveIP(ctx, host, D.TypeA)
select {
case ipv6s, open := <-ch:
@ -95,17 +95,17 @@ func (r *Resolver) ResolveAllIP(host string) (ips []netip.Addr, err error) {
return ips, nil
}
func (r *Resolver) ResolveAllIPv4(host string) (ips []netip.Addr, err error) {
return r.resolveIP(host, D.TypeA)
func (r *Resolver) LookupIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) {
return r.resolveIP(ctx, host, D.TypeA)
}
func (r *Resolver) ResolveAllIPv6(host string) (ips []netip.Addr, err error) {
return r.resolveIP(host, D.TypeAAAA)
func (r *Resolver) LookupIPv6(ctx context.Context, host string) (ips []netip.Addr, err error) {
return r.resolveIP(ctx, host, D.TypeAAAA)
}
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA
func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) {
if ips, err := r.ResolveAllIPPrimaryIPv4(host); err == nil {
func (r *Resolver) ResolveIP(ctx context.Context, host string) (ip netip.Addr, err error) {
if ips, err := r.LookupIPPrimaryIPv4(ctx, host); err == nil {
return ips[rand.Intn(len(ips))], nil
} else {
return netip.Addr{}, err
@ -113,8 +113,8 @@ func (r *Resolver) ResolveIP(host string) (ip netip.Addr, err error) {
}
// ResolveIPv4 request with TypeA
func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) {
if ips, err := r.ResolveAllIPv4(host); err == nil {
func (r *Resolver) ResolveIPv4(ctx context.Context, host string) (ip netip.Addr, err error) {
if ips, err := r.LookupIPv4(ctx, host); err == nil {
return ips[rand.Intn(len(ips))], nil
} else {
return netip.Addr{}, err
@ -122,8 +122,8 @@ func (r *Resolver) ResolveIPv4(host string) (ip netip.Addr, err error) {
}
// ResolveIPv6 request with TypeAAAA
func (r *Resolver) ResolveIPv6(host string) (ip netip.Addr, err error) {
if ips, err := r.ResolveAllIPv6(host); err == nil {
func (r *Resolver) ResolveIPv6(ctx context.Context, host string) (ip netip.Addr, err error) {
if ips, err := r.LookupIPv6(ctx, host); err == nil {
return ips[rand.Intn(len(ips))], nil
} else {
return netip.Addr{}, err
@ -305,7 +305,7 @@ func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err er
return
}
func (r *Resolver) resolveIP(host string, dnsType uint16) (ips []netip.Addr, err error) {
func (r *Resolver) resolveIP(ctx context.Context, host string, dnsType uint16) (ips []netip.Addr, err error) {
ip, err := netip.ParseAddr(host)
if err == nil {
isIPv4 := ip.Is4()
@ -321,7 +321,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ips []netip.Addr, err
query := &D.Msg{}
query.SetQuestion(D.Fqdn(host), dnsType)
msg, err := r.Exchange(query)
msg, err := r.ExchangeContext(ctx, query)
if err != nil {
return []netip.Addr{}, err
}

View File

@ -1,6 +1,7 @@
package dns
import (
stdContext "context"
"errors"
"net"
@ -25,7 +26,7 @@ type Server struct {
// ServeDNS implement D.Handler ServeDNS
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
msg, err := handlerWithContext(s.handler, r)
msg, err := handlerWithContext(stdContext.Background(), s.handler, r)
if err != nil {
D.HandleFailed(w, r)
return
@ -34,12 +35,12 @@ func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
w.WriteMsg(msg)
}
func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) {
func handlerWithContext(stdCtx stdContext.Context, handler handler, msg *D.Msg) (*D.Msg, error) {
if len(msg.Question) == 0 {
return nil, errors.New("at least one question is required")
}
ctx := context.NewDNSContext(msg)
ctx := context.NewDNSContext(stdCtx, msg)
return handler(ctx, msg)
}