Feature: add dhcp type dns client (#1509)

This commit is contained in:
Kr328
2021-09-06 23:07:34 +08:00
committed by GitHub
parent a2d59d6ef5
commit a5b950a779
26 changed files with 759 additions and 288 deletions

View File

@ -2,6 +2,7 @@ package dns
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
@ -34,22 +35,16 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err
}
}
d, err := dialer.Dialer()
network := "udp"
if strings.HasPrefix(c.Client.Net, "tcp") {
network = "tcp"
}
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), c.port))
if err != nil {
return nil, err
}
if ip != nil && ip.IsGlobalUnicast() && dialer.DialHook != nil {
network := "udp"
if strings.HasPrefix(c.Client.Net, "tcp") {
network = "tcp"
}
if err := dialer.DialHook(d, network, ip); err != nil {
return nil, err
}
}
c.Client.Dialer = d
defer conn.Close()
// miekg/dns ExchangeContext doesn't respond to context cancel.
// this is a workaround
@ -59,7 +54,17 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err
}
ch := make(chan result, 1)
go func() {
msg, _, err := c.Client.Exchange(m, net.JoinHostPort(ip.String(), c.port))
if strings.HasSuffix(c.Client.Net, "tls") {
conn = tls.Client(conn, c.Client.TLSConfig)
}
msg, _, err = c.Client.ExchangeWithConn(m, &D.Conn{
Conn: conn,
UDPSize: c.Client.UDPSize,
TsigSecret: c.Client.TsigSecret,
TsigProvider: c.Client.TsigProvider,
})
ch <- result{msg, err}
}()

144
dns/dhcp.go Normal file
View File

@ -0,0 +1,144 @@
package dns
import (
"bytes"
"context"
"net"
"sync"
"time"
"github.com/Dreamacro/clash/component/dhcp"
"github.com/Dreamacro/clash/component/iface"
"github.com/Dreamacro/clash/component/resolver"
D "github.com/miekg/dns"
)
const (
IfaceTTL = time.Second * 20
DHCPTTL = time.Hour
DHCPTimeout = time.Minute
)
type dhcpClient struct {
ifaceName string
lock sync.Mutex
ifaceInvalidate time.Time
dnsInvalidate time.Time
ifaceAddr *net.IPNet
done chan struct{}
resolver *Resolver
err error
}
func (d *dhcpClient) Exchange(m *D.Msg) (msg *D.Msg, err error) {
ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
defer cancel()
return d.ExchangeContext(ctx, m)
}
func (d *dhcpClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
res, err := d.resolve(ctx)
if err != nil {
return nil, err
}
return res.ExchangeContext(ctx, m)
}
func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) {
d.lock.Lock()
invalidated, err := d.invalidate()
if err != nil {
d.err = err
} else if invalidated {
done := make(chan struct{})
d.done = done
go func() {
ctx, cancel := context.WithTimeout(context.Background(), DHCPTimeout)
defer cancel()
var res *Resolver
dns, err := dhcp.ResolveDNSFromDHCP(ctx, d.ifaceName)
if err == nil {
nameserver := make([]NameServer, 0, len(dns))
for _, d := range dns {
nameserver = append(nameserver, NameServer{Addr: net.JoinHostPort(d.String(), "53")})
}
res = NewResolver(Config{
Main: nameserver,
})
}
d.lock.Lock()
defer d.lock.Unlock()
close(done)
d.done = nil
d.resolver = res
d.err = err
}()
}
d.lock.Unlock()
for {
d.lock.Lock()
res, err, done := d.resolver, d.err, d.done
d.lock.Unlock()
// initializing
if res == nil && err == nil {
select {
case <-done:
continue
case <-ctx.Done():
return nil, ctx.Err()
}
}
// dirty return
return res, err
}
}
func (d *dhcpClient) invalidate() (bool, error) {
if time.Now().Before(d.ifaceInvalidate) {
return false, nil
}
d.ifaceInvalidate = time.Now().Add(IfaceTTL)
ifaceObj, err := iface.ResolveInterface(d.ifaceName)
if err != nil {
return false, err
}
addr, err := ifaceObj.PickIPv4Addr(nil)
if err != nil {
return false, err
}
if time.Now().Before(d.dnsInvalidate) && d.ifaceAddr.IP.Equal(addr.IP) && bytes.Equal(d.ifaceAddr.Mask, addr.Mask) {
return false, nil
}
d.dnsInvalidate = time.Now().Add(DHCPTTL)
d.ifaceAddr = addr
return d.done == nil, nil
}
func newDHCPClient(ifaceName string) *dhcpClient {
return &dhcpClient{ifaceName: ifaceName}
}

View File

@ -87,6 +87,11 @@ func (r *Resolver) shouldIPFallback(ip net.IP) bool {
// Exchange a batch of dns request, and it use cache
func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
return r.ExchangeContext(context.Background(), m)
}
// ExchangeContext a batch of dns request with context.Context, and it use cache
func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
if len(m.Question) == 0 {
return nil, errors.New("should have one question at least")
}
@ -98,17 +103,17 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) {
msg = cache.(*D.Msg).Copy()
if expireTime.Before(now) {
setMsgTTL(msg, uint32(1)) // Continue fetch
go r.exchangeWithoutCache(m)
go r.exchangeWithoutCache(ctx, m)
} else {
setMsgTTL(msg, uint32(time.Until(expireTime).Seconds()))
}
return
}
return r.exchangeWithoutCache(m)
return r.exchangeWithoutCache(ctx, m)
}
// ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache
func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
q := m.Question[0]
ret, err, shared := r.group.Do(q.String(), func() (result interface{}, err error) {
@ -124,13 +129,13 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
isIPReq := isIPRequest(q)
if isIPReq {
return r.ipExchange(m)
return r.ipExchange(ctx, m)
}
if matched := r.matchPolicy(m); len(matched) != 0 {
return r.batchExchange(matched, m)
return r.batchExchange(ctx, matched, m)
}
return r.batchExchange(r.main, m)
return r.batchExchange(ctx, r.main, m)
})
if err == nil {
@ -143,8 +148,8 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
return
}
func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
fast, ctx := picker.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
func (r *Resolver) batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
fast, ctx := picker.WithTimeout(ctx, resolver.DefaultDNSTimeout)
for _, client := range clients {
r := client
fast.Go(func() (interface{}, error) {
@ -209,21 +214,21 @@ func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
return false
}
func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
if matched := r.matchPolicy(m); len(matched) != 0 {
res := <-r.asyncExchange(matched, m)
res := <-r.asyncExchange(ctx, matched, m)
return res.Msg, res.Error
}
onlyFallback := r.shouldOnlyQueryFallback(m)
if onlyFallback {
res := <-r.asyncExchange(r.fallback, m)
res := <-r.asyncExchange(ctx, r.fallback, m)
return res.Msg, res.Error
}
msgCh := r.asyncExchange(r.main, m)
msgCh := r.asyncExchange(ctx, r.main, m)
if r.fallback == nil { // directly return if no fallback servers are available
res := <-msgCh
@ -231,7 +236,7 @@ func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
return
}
fallbackMsg := r.asyncExchange(r.fallback, m)
fallbackMsg := r.asyncExchange(ctx, r.fallback, m)
res := <-msgCh
if res.Error == nil {
if ips := msgToIP(res.Msg); len(ips) != 0 {
@ -287,10 +292,10 @@ func (r *Resolver) msgToDomain(msg *D.Msg) string {
return ""
}
func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result {
func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result {
ch := make(chan *result, 1)
go func() {
res, err := r.batchExchange(client, msg)
res, err := r.batchExchange(ctx, client, msg)
ch <- &result{Msg: res, Error: err}
}()
return ch

View File

@ -117,9 +117,13 @@ func isIPRequest(q D.Question) bool {
func transform(servers []NameServer, resolver *Resolver) []dnsClient {
ret := []dnsClient{}
for _, s := range servers {
if s.Net == "https" {
switch s.Net {
case "https":
ret = append(ret, newDoHClient(s.Addr, resolver))
continue
case "dhcp":
ret = append(ret, newDHCPClient(s.Addr))
continue
}
host, port, _ := net.SplitHostPort(s.Addr)