Refactor: make inbound request contextual
This commit is contained in:
@ -8,26 +8,27 @@ import (
|
||||
"github.com/Dreamacro/clash/common/cache"
|
||||
"github.com/Dreamacro/clash/component/fakeip"
|
||||
"github.com/Dreamacro/clash/component/trie"
|
||||
"github.com/Dreamacro/clash/context"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
D "github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type handler func(r *D.Msg) (*D.Msg, error)
|
||||
type handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
|
||||
type middleware func(next handler) handler
|
||||
|
||||
func withHosts(hosts *trie.DomainTrie) middleware {
|
||||
return func(next handler) handler {
|
||||
return func(r *D.Msg) (*D.Msg, error) {
|
||||
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
|
||||
q := r.Question[0]
|
||||
|
||||
if !isIPRequest(q) {
|
||||
return next(r)
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
record := hosts.Search(strings.TrimRight(q.Name, "."))
|
||||
if record == nil {
|
||||
return next(r)
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
ip := record.Data.(net.IP)
|
||||
@ -46,9 +47,10 @@ func withHosts(hosts *trie.DomainTrie) middleware {
|
||||
|
||||
msg.Answer = []D.RR{rr}
|
||||
} else {
|
||||
return next(r)
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
ctx.SetType(context.DNSTypeHost)
|
||||
msg.SetRcode(r, D.RcodeSuccess)
|
||||
msg.Authoritative = true
|
||||
msg.RecursionAvailable = true
|
||||
@ -60,14 +62,14 @@ func withHosts(hosts *trie.DomainTrie) middleware {
|
||||
|
||||
func withMapping(mapping *cache.LruCache) middleware {
|
||||
return func(next handler) handler {
|
||||
return func(r *D.Msg) (*D.Msg, error) {
|
||||
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
|
||||
q := r.Question[0]
|
||||
|
||||
if !isIPRequest(q) {
|
||||
return next(r)
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
msg, err := next(r)
|
||||
msg, err := next(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -99,12 +101,12 @@ func withMapping(mapping *cache.LruCache) middleware {
|
||||
|
||||
func withFakeIP(fakePool *fakeip.Pool) middleware {
|
||||
return func(next handler) handler {
|
||||
return func(r *D.Msg) (*D.Msg, error) {
|
||||
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
|
||||
q := r.Question[0]
|
||||
|
||||
host := strings.TrimRight(q.Name, ".")
|
||||
if fakePool.LookupHost(host) {
|
||||
return next(r)
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
switch q.Qtype {
|
||||
@ -113,7 +115,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
|
||||
}
|
||||
|
||||
if q.Qtype != D.TypeA {
|
||||
return next(r)
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
rr := &D.A{}
|
||||
@ -123,6 +125,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
|
||||
msg := r.Copy()
|
||||
msg.Answer = []D.RR{rr}
|
||||
|
||||
ctx.SetType(context.DNSTypeFakeIP)
|
||||
setMsgTTL(msg, 1)
|
||||
msg.SetRcode(r, D.RcodeSuccess)
|
||||
msg.Authoritative = true
|
||||
@ -134,7 +137,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
|
||||
}
|
||||
|
||||
func withResolver(resolver *Resolver) handler {
|
||||
return func(r *D.Msg) (*D.Msg, error) {
|
||||
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
|
||||
ctx.SetType(context.DNSTypeRaw)
|
||||
q := r.Question[0]
|
||||
|
||||
// return a empty AAAA msg when ipv6 disabled
|
||||
|
@ -212,7 +212,7 @@ func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
|
||||
fallbackMsg := r.asyncExchange(r.fallback, m)
|
||||
res := <-msgCh
|
||||
if res.Error == nil {
|
||||
if ips := r.msgToIP(res.Msg); len(ips) != 0 {
|
||||
if ips := msgToIP(res.Msg); len(ips) != 0 {
|
||||
if !r.shouldIPFallback(ips[0]) {
|
||||
msg = res.Msg // no need to wait for fallback result
|
||||
err = res.Error
|
||||
@ -247,7 +247,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ips := r.msgToIP(msg)
|
||||
ips := msgToIP(msg)
|
||||
ipLength := len(ips)
|
||||
if ipLength == 0 {
|
||||
return nil, resolver.ErrIPNotFound
|
||||
@ -257,21 +257,6 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
|
||||
ips := []net.IP{}
|
||||
|
||||
for _, answer := range msg.Answer {
|
||||
switch ans := answer.(type) {
|
||||
case *D.AAAA:
|
||||
ips = append(ips, ans.AAAA)
|
||||
case *D.A:
|
||||
ips = append(ips, ans.A)
|
||||
}
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
||||
func (r *Resolver) msgToDomain(msg *D.Msg) string {
|
||||
if len(msg.Question) > 0 {
|
||||
return strings.TrimRight(msg.Question[0].Name, ".")
|
||||
|
@ -1,9 +1,11 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/Dreamacro/clash/common/sockopt"
|
||||
"github.com/Dreamacro/clash/context"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
|
||||
D "github.com/miekg/dns"
|
||||
@ -21,21 +23,25 @@ type Server struct {
|
||||
handler handler
|
||||
}
|
||||
|
||||
// ServeDNS implement D.Handler ServeDNS
|
||||
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
D.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := s.handler(r)
|
||||
msg, err := handlerWithContext(s.handler, r)
|
||||
if err != nil {
|
||||
D.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteMsg(msg)
|
||||
}
|
||||
|
||||
func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) {
|
||||
if len(msg.Question) == 0 {
|
||||
return nil, errors.New("at least one question is required")
|
||||
}
|
||||
|
||||
ctx := context.NewDNSContext(msg)
|
||||
return handler(ctx, msg)
|
||||
}
|
||||
|
||||
func (s *Server) setHandler(handler handler) {
|
||||
s.handler = handler
|
||||
}
|
||||
|
15
dns/util.go
15
dns/util.go
@ -153,3 +153,18 @@ func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func msgToIP(msg *D.Msg) []net.IP {
|
||||
ips := []net.IP{}
|
||||
|
||||
for _, answer := range msg.Answer {
|
||||
switch ans := answer.(type) {
|
||||
case *D.AAAA:
|
||||
ips = append(ips, ans.AAAA)
|
||||
case *D.A:
|
||||
ips = append(ips, ans.A)
|
||||
}
|
||||
}
|
||||
|
||||
return ips
|
||||
}
|
||||
|
Reference in New Issue
Block a user