Refactor: make inbound request contextual

This commit is contained in:
Dreamacro
2021-01-23 14:49:46 +08:00
parent 35925cb3da
commit f4de055aa1
19 changed files with 302 additions and 125 deletions

View File

@ -8,26 +8,27 @@ import (
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/trie"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
)
type handler func(r *D.Msg) (*D.Msg, error)
type handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
type middleware func(next handler) handler
func withHosts(hosts *trie.DomainTrie) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(r)
return next(ctx, r)
}
record := hosts.Search(strings.TrimRight(q.Name, "."))
if record == nil {
return next(r)
return next(ctx, r)
}
ip := record.Data.(net.IP)
@ -46,9 +47,10 @@ func withHosts(hosts *trie.DomainTrie) middleware {
msg.Answer = []D.RR{rr}
} else {
return next(r)
return next(ctx, r)
}
ctx.SetType(context.DNSTypeHost)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
msg.RecursionAvailable = true
@ -60,14 +62,14 @@ func withHosts(hosts *trie.DomainTrie) middleware {
func withMapping(mapping *cache.LruCache) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
if !isIPRequest(q) {
return next(r)
return next(ctx, r)
}
msg, err := next(r)
msg, err := next(ctx, r)
if err != nil {
return nil, err
}
@ -99,12 +101,12 @@ func withMapping(mapping *cache.LruCache) middleware {
func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
q := r.Question[0]
host := strings.TrimRight(q.Name, ".")
if fakePool.LookupHost(host) {
return next(r)
return next(ctx, r)
}
switch q.Qtype {
@ -113,7 +115,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
}
if q.Qtype != D.TypeA {
return next(r)
return next(ctx, r)
}
rr := &D.A{}
@ -123,6 +125,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
msg := r.Copy()
msg.Answer = []D.RR{rr}
ctx.SetType(context.DNSTypeFakeIP)
setMsgTTL(msg, 1)
msg.SetRcode(r, D.RcodeSuccess)
msg.Authoritative = true
@ -134,7 +137,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware {
}
func withResolver(resolver *Resolver) handler {
return func(r *D.Msg) (*D.Msg, error) {
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
ctx.SetType(context.DNSTypeRaw)
q := r.Question[0]
// return a empty AAAA msg when ipv6 disabled

View File

@ -212,7 +212,7 @@ func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
fallbackMsg := r.asyncExchange(r.fallback, m)
res := <-msgCh
if res.Error == nil {
if ips := r.msgToIP(res.Msg); len(ips) != 0 {
if ips := msgToIP(res.Msg); len(ips) != 0 {
if !r.shouldIPFallback(ips[0]) {
msg = res.Msg // no need to wait for fallback result
err = res.Error
@ -247,7 +247,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
return nil, err
}
ips := r.msgToIP(msg)
ips := msgToIP(msg)
ipLength := len(ips)
if ipLength == 0 {
return nil, resolver.ErrIPNotFound
@ -257,21 +257,6 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
return
}
func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
return ips
}
func (r *Resolver) msgToDomain(msg *D.Msg) string {
if len(msg.Question) > 0 {
return strings.TrimRight(msg.Question[0].Name, ".")

View File

@ -1,9 +1,11 @@
package dns
import (
"errors"
"net"
"github.com/Dreamacro/clash/common/sockopt"
"github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
@ -21,21 +23,25 @@ type Server struct {
handler handler
}
// ServeDNS implement D.Handler ServeDNS
func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) {
if len(r.Question) == 0 {
D.HandleFailed(w, r)
return
}
msg, err := s.handler(r)
msg, err := handlerWithContext(s.handler, r)
if err != nil {
D.HandleFailed(w, r)
return
}
w.WriteMsg(msg)
}
func handlerWithContext(handler handler, msg *D.Msg) (*D.Msg, error) {
if len(msg.Question) == 0 {
return nil, errors.New("at least one question is required")
}
ctx := context.NewDNSContext(msg)
return handler(ctx, msg)
}
func (s *Server) setHandler(handler handler) {
s.handler = handler
}

View File

@ -153,3 +153,18 @@ func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
return msg
}
func msgToIP(msg *D.Msg) []net.IP {
ips := []net.IP{}
for _, answer := range msg.Answer {
switch ans := answer.(type) {
case *D.AAAA:
ips = append(ips, ans.AAAA)
case *D.A:
ips = append(ips, ans.A)
}
}
return ips
}