Feature: move hosts to the top

This commit is contained in:
Dreamacro
2019-09-11 17:00:55 +08:00
parent 16e3090ee8
commit 96a4abf46c
6 changed files with 164 additions and 122 deletions

View File

@ -1,8 +1,6 @@
package dns
import (
"fmt"
"net"
"strings"
"github.com/Dreamacro/clash/component/fakeip"
@ -12,34 +10,40 @@ import (
)
type handler func(w D.ResponseWriter, r *D.Msg)
type middleware func(next handler) handler
func withFakeIP(pool *fakeip.Pool) handler {
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
host := strings.TrimRight(q.Name, ".")
func withFakeIP(fakePool *fakeip.Pool) middleware {
return func(next handler) handler {
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA {
next(w, r)
return
}
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := pool.Lookup(host)
rr.A = ip
msg := r.Copy()
msg.Answer = []D.RR{rr}
host := strings.TrimRight(q.Name, ".")
setMsgTTL(msg, 1)
msg.SetReply(r)
w.WriteMsg(msg)
return
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := fakePool.Lookup(host)
rr.A = ip
msg := r.Copy()
msg.Answer = []D.RR{rr}
setMsgTTL(msg, 1)
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}
}
func withResolver(resolver *Resolver) handler {
return func(w D.ResponseWriter, r *D.Msg) {
msg, err := resolver.Exchange(r)
if err != nil {
q := r.Question[0]
qString := fmt.Sprintf("%s %s %s", q.Name, D.Class(q.Qclass).String(), D.Type(q.Qtype).String())
log.Debugln("[DNS Server] Exchange %s failed: %v", qString, err)
log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
D.HandleFailed(w, r)
return
}
@ -49,64 +53,23 @@ func withResolver(resolver *Resolver) handler {
}
}
func withHost(resolver *Resolver, next handler) handler {
hosts := resolver.hosts
if hosts == nil {
panic("dns/withHost: hosts should not be nil")
func compose(middlewares []middleware, endpoint handler) handler {
length := len(middlewares)
h := endpoint
for i := length - 1; i >= 0; i-- {
middleware := middlewares[i]
h = middleware(h)
}
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
if q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA {
next(w, r)
return
}
domain := strings.TrimRight(q.Name, ".")
host := hosts.Search(domain)
if host == nil {
next(w, r)
return
}
ip := host.Data.(net.IP)
if q.Qtype == D.TypeAAAA && ip.To16() == nil {
next(w, r)
return
} else if q.Qtype == D.TypeA && ip.To4() == nil {
next(w, r)
return
}
var rr D.RR
if q.Qtype == D.TypeAAAA {
record := &D.AAAA{}
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
record.AAAA = ip
rr = record
} else {
record := &D.A{}
record.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
record.A = ip
rr = record
}
msg := r.Copy()
msg.Answer = []D.RR{rr}
msg.SetReply(r)
w.WriteMsg(msg)
return
}
return h
}
func newHandler(resolver *Resolver) handler {
middlewares := []middleware{}
if resolver.IsFakeIP() {
return withFakeIP(resolver.pool)
middlewares = append(middlewares, withFakeIP(resolver.pool))
}
if resolver.hosts != nil {
return withHost(resolver, withResolver(resolver))
}
return withResolver(resolver)
return compose(middlewares, withResolver(resolver))
}