Refactor: DomainTrie use generics
This commit is contained in:
@ -70,13 +70,13 @@ type fallbackDomainFilter interface {
|
||||
}
|
||||
|
||||
type domainFilter struct {
|
||||
tree *trie.DomainTrie
|
||||
tree *trie.DomainTrie[bool]
|
||||
}
|
||||
|
||||
func NewDomainFilter(domains []string) *domainFilter {
|
||||
df := domainFilter{tree: trie.New()}
|
||||
df := domainFilter{tree: trie.New[bool]()}
|
||||
for _, domain := range domains {
|
||||
df.tree.Insert(domain, "")
|
||||
df.tree.Insert(domain, true)
|
||||
}
|
||||
return &df
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -20,7 +21,7 @@ type (
|
||||
middleware func(next handler) handler
|
||||
)
|
||||
|
||||
func withHosts(hosts *trie.DomainTrie) middleware {
|
||||
func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware {
|
||||
return func(next handler) handler {
|
||||
return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
|
||||
q := r.Question[0]
|
||||
@ -34,19 +35,19 @@ func withHosts(hosts *trie.DomainTrie) middleware {
|
||||
return next(ctx, r)
|
||||
}
|
||||
|
||||
ip := record.Data.(net.IP)
|
||||
ip := record.Data
|
||||
msg := r.Copy()
|
||||
|
||||
if v4 := ip.To4(); v4 != nil && q.Qtype == D.TypeA {
|
||||
if ip.Is4() && q.Qtype == D.TypeA {
|
||||
rr := &D.A{}
|
||||
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
|
||||
rr.A = v4
|
||||
rr.A = ip.AsSlice()
|
||||
|
||||
msg.Answer = []D.RR{rr}
|
||||
} else if v6 := ip.To16(); v6 != nil && q.Qtype == D.TypeAAAA {
|
||||
} else if ip.Is6() && q.Qtype == D.TypeAAAA {
|
||||
rr := &D.AAAA{}
|
||||
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
|
||||
rr.AAAA = v6
|
||||
rr.AAAA = ip.AsSlice()
|
||||
|
||||
msg.Answer = []D.RR{rr}
|
||||
} else {
|
||||
|
30
dns/policy.go
Normal file
30
dns/policy.go
Normal file
@ -0,0 +1,30 @@
|
||||
package dns
|
||||
|
||||
type Policy struct {
|
||||
data []dnsClient
|
||||
}
|
||||
|
||||
func (p *Policy) GetData() []dnsClient {
|
||||
return p.data
|
||||
}
|
||||
|
||||
func (p *Policy) Compare(p2 *Policy) int {
|
||||
if p2 == nil {
|
||||
return 1
|
||||
}
|
||||
l1 := len(p.data)
|
||||
l2 := len(p2.data)
|
||||
if l1 == l2 {
|
||||
return 0
|
||||
}
|
||||
if l1 > l2 {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func NewPolicy(data []dnsClient) *Policy {
|
||||
return &Policy{
|
||||
data: data,
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -33,14 +34,14 @@ type result struct {
|
||||
|
||||
type Resolver struct {
|
||||
ipv6 bool
|
||||
hosts *trie.DomainTrie
|
||||
hosts *trie.DomainTrie[netip.Addr]
|
||||
main []dnsClient
|
||||
fallback []dnsClient
|
||||
fallbackDomainFilters []fallbackDomainFilter
|
||||
fallbackIPFilters []fallbackIPFilter
|
||||
group singleflight.Group
|
||||
lruCache *cache.LruCache[string, *D.Msg]
|
||||
policy *trie.DomainTrie
|
||||
policy *trie.DomainTrie[*Policy]
|
||||
proxyServer []dnsClient
|
||||
}
|
||||
|
||||
@ -194,7 +195,8 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
|
||||
return nil
|
||||
}
|
||||
|
||||
return record.Data.([]dnsClient)
|
||||
p := record.Data
|
||||
return p.GetData()
|
||||
}
|
||||
|
||||
func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
|
||||
@ -329,7 +331,7 @@ type Config struct {
|
||||
EnhancedMode C.DNSMode
|
||||
FallbackFilter FallbackFilter
|
||||
Pool *fakeip.Pool
|
||||
Hosts *trie.DomainTrie
|
||||
Hosts *trie.DomainTrie[netip.Addr]
|
||||
Policy map[string]NameServer
|
||||
}
|
||||
|
||||
@ -355,9 +357,9 @@ func NewResolver(config Config) *Resolver {
|
||||
}
|
||||
|
||||
if len(config.Policy) != 0 {
|
||||
r.policy = trie.New()
|
||||
r.policy = trie.New[*Policy]()
|
||||
for domain, nameserver := range config.Policy {
|
||||
r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver))
|
||||
r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver)))
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user