feat: nameserver-policy support domain-set, reduce domain-set memory usage

This commit is contained in:
Skyxim 2023-03-29 13:14:02 +08:00
parent 56e525114d
commit d52748165f
8 changed files with 256 additions and 35 deletions

8
common/utils/strings.go Normal file
View File

@ -0,0 +1,8 @@
package utils
func Reverse(s string) string {
a := []rune(s)
for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 {
a[i], a[j] = a[j], a[i]
}
return string(a)
}

158
component/trie/sskv.go Normal file
View File

@ -0,0 +1,158 @@
// Package succinct provides several succinct data types.
// Modify from https://github.com/openacid/succinct/sskv.go
package trie
import (
"sort"
"strings"
"github.com/Dreamacro/clash/common/utils"
"github.com/openacid/low/bitmap"
)
const (
complexWildcardByte = byte('+')
wildcardByte = byte('*')
domainStepByte = byte('.')
)
type Set struct {
leaves, labelBitmap []uint64
labels []byte
ranks, selects []int32
}
// NewSet creates a new *Set struct, from a slice of sorted strings.
func NewDomainTrieSet(keys []string) *Set {
filter := make(map[string]struct{}, len(keys))
reserveDomains := make([]string, 0, len(keys))
for _, key := range keys {
items, ok := ValidAndSplitDomain(key)
if !ok {
continue
}
if items[0] == complexWildcard {
domain := strings.Join(items[1:], domainStep)
reserveDomain := utils.Reverse(domain)
filter[reserveDomain] = struct{}{}
reserveDomains = append(reserveDomains, reserveDomain)
}
domain := strings.Join(items, domainStep)
reserveDomain := utils.Reverse(domain)
filter[reserveDomain] = struct{}{}
reserveDomains = append(reserveDomains, reserveDomain)
}
sort.Strings(reserveDomains)
keys=reserveDomains
ss := &Set{}
lIdx := 0
type qElt struct{ s, e, col int }
queue := []qElt{{0, len(keys), 0}}
for i := 0; i < len(queue); i++ {
elt := queue[i]
if elt.col == len(keys[elt.s]) {
elt.s++
// a leaf node
setBit(&ss.leaves, i, 1)
}
for j := elt.s; j < elt.e; {
frm := j
for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ {
}
queue = append(queue, qElt{frm, j, elt.col + 1})
ss.labels = append(ss.labels, keys[frm][elt.col])
setBit(&ss.labelBitmap, lIdx, 0)
lIdx++
}
setBit(&ss.labelBitmap, lIdx, 1)
lIdx++
}
ss.init()
return ss
}
// Has query for a key and return whether it presents in the Set.
func (ss *Set) Has(key string) bool {
key=utils.Reverse(key)
// no more labels in this node
// skip character matching
// go to next level
nodeId, bmIdx := 0, 0
for i := 0; i < len(key); i++ {
c := key[i]
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return false
}
// handle wildcard for domain
if ss.labels[bmIdx-nodeId] == complexWildcardByte {
return true
} else if ss.labels[bmIdx-nodeId] == wildcardByte {
j := i
for ; j < len(key); j++ {
if key[j] == domainStepByte {
i = j
goto END
}
}
return true
} else if ss.labels[bmIdx-nodeId] == c {
break
}
}
END:
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
}
if getBit(ss.leaves, nodeId) != 0 {
return true
} else {
return false
}
}
func setBit(bm *[]uint64, i int, v int) {
for i>>6 >= len(*bm) {
*bm = append(*bm, 0)
}
(*bm)[i>>6] |= uint64(v) << uint(i&63)
}
func getBit(bm []uint64, i int) uint64 {
return bm[i>>6] & (1 << uint(i&63))
}
// init builds pre-calculated cache to speed up rank() and select()
func (ss *Set) init() {
ss.selects, ss.ranks = bitmap.IndexSelect32R64(ss.labelBitmap)
}
// countZeros counts the number of "0" in a bitmap before the i-th bit(excluding
// the i-th bit) on behalf of rank index.
// E.g.:
//
// countZeros("010010", 4) == 3
// // 012345
func countZeros(bm []uint64, ranks []int32, i int) int {
a, _ := bitmap.Rank64(bm, ranks, int32(i))
return i - int(a)
}
// selectIthOne returns the index of the i-th "1" in a bitmap, on behalf of rank
// and select indexes.
// E.g.:
//
// selectIthOne("010010", 1) == 4
// // 012345
func selectIthOne(bm []uint64, ranks, selects []int32, i int) int {
a, _ := bitmap.Select32R64(bm, selects, ranks, int32(i))
return int(a)
}

View File

@ -490,7 +490,7 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) {
}
config.Hosts = hosts
dnsCfg, err := parseDNS(rawCfg, hosts, rules)
dnsCfg, err := parseDNS(rawCfg, hosts, rules, ruleProviders)
if err != nil {
return nil, err
}
@ -983,7 +983,7 @@ func parsePureDNSServer(server string) string {
}
}
}
func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][]dns.NameServer, error) {
func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]providerTypes.RuleProvider, preferH3 bool) (map[string][]dns.NameServer, error) {
policy := map[string][]dns.NameServer{}
updatedPolicy := make(map[string]interface{})
re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`)
@ -998,6 +998,14 @@ func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][
newKey := "geosite:" + subkey
updatedPolicy[newKey] = v
}
} else if strings.Contains(k, "domain-set:") {
subkeys := strings.Split(k, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, subkey := range subkeys {
newKey := "domain-set:" + subkey
updatedPolicy[newKey] = v
}
} else if re.MatchString(k) {
subkeys := strings.Split(k, ",")
for _, subkey := range subkeys {
@ -1021,6 +1029,14 @@ func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][
if _, valid := trie.ValidAndSplitDomain(domain); !valid {
return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain)
}
if strings.HasPrefix(domain, "domain-set:") {
domainSetName := domain[11:]
if provider, ok := ruleProviders[domainSetName]; !ok {
return nil, fmt.Errorf("not found domain-set: %s", domainSetName)
} else if provider.Behavior() != providerTypes.Domain {
return nil, fmt.Errorf("rule provider type error, except domain,actual %s", provider.Behavior())
}
}
policy[domain] = nameservers
}
@ -1077,7 +1093,7 @@ func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainM
return sites, nil
}
func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rules []C.Rule) (*DNS, error) {
func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rules []C.Rule, ruleProviders map[string]providerTypes.RuleProvider) (*DNS, error) {
cfg := rawCfg.DNS
if cfg.Enable && len(cfg.NameServer) == 0 {
return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty")
@ -1104,7 +1120,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
return nil, err
}
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, cfg.PreferH3); err != nil {
if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, ruleProviders, cfg.PreferH3); err != nil {
return nil, err
}

View File

@ -16,6 +16,7 @@ import (
"github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/constant/provider"
"github.com/Dreamacro/clash/log"
D "github.com/miekg/dns"
@ -40,6 +41,11 @@ type geositePolicyRecord struct {
inversedMatching bool
}
type domainSetPolicyRecord struct {
domainSetProvider provider.RuleProvider
policy *Policy
}
type Resolver struct {
ipv6 bool
ipv6Timeout time.Duration
@ -51,6 +57,7 @@ type Resolver struct {
group singleflight.Group
lruCache *cache.LruCache[string, *D.Msg]
policy *trie.DomainTrie[*Policy]
domainSetPolicy []domainSetPolicyRecord
geositePolicy []geositePolicyRecord
proxyServer []dnsClient
}
@ -301,6 +308,12 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
return geositeRecord.policy.GetData()
}
}
metadata := &C.Metadata{Host: domain}
for _, domainSetRecord := range r.domainSetPolicy {
if ok := domainSetRecord.domainSetProvider.Match(metadata); ok {
return domainSetRecord.policy.GetData()
}
}
return nil
}
@ -422,16 +435,18 @@ type FallbackFilter struct {
}
type Config struct {
Main, Fallback []NameServer
Default []NameServer
ProxyServer []NameServer
IPv6 bool
IPv6Timeout uint
EnhancedMode C.DNSMode
FallbackFilter FallbackFilter
Pool *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue]
Policy map[string][]NameServer
Main, Fallback []NameServer
Default []NameServer
ProxyServer []NameServer
IPv6 bool
IPv6Timeout uint
EnhancedMode C.DNSMode
FallbackFilter FallbackFilter
Pool *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue]
Policy map[string][]NameServer
DomainSetPolicy map[provider.RuleProvider][]NameServer
GeositePolicy map[router.DomainMatcher][]NameServer
}
func NewResolver(config Config) *Resolver {
@ -483,6 +498,14 @@ func NewResolver(config Config) *Resolver {
}
r.policy.Optimize()
}
if len(config.DomainSetPolicy) > 0 {
for p, n := range config.DomainSetPolicy {
r.domainSetPolicy = append(r.domainSetPolicy, domainSetPolicyRecord{
domainSetProvider: p,
policy: NewPolicy(transform(n, defaultResolver)),
})
}
}
fallbackIPFilters := []fallbackIPFilter{}
if config.FallbackFilter.GeoIP {

1
go.mod
View File

@ -67,6 +67,7 @@ require (
github.com/mdlayher/socket v0.4.0 // indirect
github.com/metacubex/gvisor v0.0.0-20230323114922-412956fb6a03 // indirect
github.com/onsi/ginkgo/v2 v2.2.0 // indirect
github.com/openacid/low v0.1.21
github.com/oschwald/maxminddb-golang v1.10.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.4.0 // indirect

9
go.sum
View File

@ -6,6 +6,7 @@ github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@ -34,6 +35,7 @@ github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
@ -108,10 +110,16 @@ github.com/mroth/weightedrand/v2 v2.0.0/go.mod h1:f2faGsfOGOwc1p94wzHKKZyTpcJUW7
github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI=
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk=
github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q=
github.com/openacid/errors v0.8.1/go.mod h1:GUQEJJOJE3W9skHm8E8Y4phdl2LLEN8iD7c5gcGgdx0=
github.com/openacid/low v0.1.21 h1:Tr2GNu4N/+rGRYdOsEHOE89cxUIaDViZbVmKz29uKGo=
github.com/openacid/low v0.1.21/go.mod h1:q+MsKI6Pz2xsCkzV4BLj7NR5M4EX0sGz5AqotpZDVh0=
github.com/openacid/must v0.1.3/go.mod h1:luPiXCuJlEo3UUFQngVQokV0MPGryeYvtCbQPs3U1+I=
github.com/openacid/testkeys v0.1.6/go.mod h1:MfA7cACzBpbiwekivj8StqX0WIRmqlMsci1c37CA3Do=
github.com/oschwald/geoip2-golang v1.8.0 h1:KfjYB8ojCEn/QLqsDU0AzrJ3R5Qa9vFlx3z6SLNcKTs=
github.com/oschwald/geoip2-golang v1.8.0/go.mod h1:R7bRvYjOeaoenAp9sKRS8GX5bJWcZ0laWO5+DauEktw=
github.com/oschwald/maxminddb-golang v1.10.0 h1:Xp1u0ZhqkSuopaKmk1WwHtjF0H9Hd9181uj2MQ5Vndg=
github.com/oschwald/maxminddb-golang v1.10.0/go.mod h1:Y2ELenReaLAZ0b400URyGwvYxHV1dLIxBuyOsyYjHK0=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
@ -148,6 +156,7 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

View File

@ -5,6 +5,7 @@ import (
"net/netip"
"os"
"runtime"
"strings"
"sync"
"github.com/Dreamacro/clash/adapter"
@ -91,7 +92,7 @@ func ApplyConfig(cfg *config.Config, force bool) {
updateSniffer(cfg.Sniffer)
updateHosts(cfg.Hosts)
updateGeneral(cfg.General)
updateDNS(cfg.DNS, cfg.General.IPv6)
updateDNS(cfg.DNS, cfg.RuleProviders, cfg.General.IPv6)
updateListeners(cfg.General, cfg.Listeners, force)
updateIPTables(cfg)
updateTun(cfg.General)
@ -178,7 +179,7 @@ func updateExperimental(c *config.Config) {
runtime.GC()
}
func updateDNS(c *config.DNS, generalIPv6 bool) {
func updateDNS(c *config.DNS, ruleProvider map[string]provider.RuleProvider, generalIPv6 bool) {
if !c.Enable {
resolver.DefaultResolver = nil
resolver.DefaultHostMapper = nil
@ -186,7 +187,25 @@ func updateDNS(c *config.DNS, generalIPv6 bool) {
dns.ReCreateServer("", nil, nil)
return
}
policy := make(map[string][]dns.NameServer)
domainSetPolicies := make(map[provider.RuleProvider][]dns.NameServer)
for key, nameservers := range c.NameServerPolicy {
temp := strings.Split(key, ":")
if len(temp) == 2 {
prefix := temp[0]
key := temp[1]
switch strings.ToLower(prefix) {
case "domain-set":
if p, ok := ruleProvider[key]; ok {
domainSetPolicies[p] = nameservers
}
case "geosite":
// TODO:
}
} else {
policy[key] = nameservers
}
}
cfg := dns.Config{
Main: c.NameServer,
Fallback: c.Fallback,
@ -205,6 +224,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) {
Default: c.DefaultNameserver,
Policy: c.NameServerPolicy,
ProxyServer: c.ProxyServerNameserver,
DomainSetPolicy: domainSetPolicies,
}
r := dns.NewResolver(cfg)

View File

@ -3,13 +3,11 @@ package provider
import (
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"golang.org/x/net/idna"
)
type domainStrategy struct {
count int
domainRules *trie.DomainTrie[struct{}]
domainRules *trie.Set
}
func (d *domainStrategy) ShouldFindProcess() bool {
@ -17,7 +15,7 @@ func (d *domainStrategy) ShouldFindProcess() bool {
}
func (d *domainStrategy) Match(metadata *C.Metadata) bool {
return d.domainRules != nil && d.domainRules.Search(metadata.RuleHost()) != nil
return d.domainRules != nil && d.domainRules.Has(metadata.RuleHost())
}
func (d *domainStrategy) Count() int {
@ -29,21 +27,9 @@ func (d *domainStrategy) ShouldResolveIP() bool {
}
func (d *domainStrategy) OnUpdate(rules []string) {
domainTrie := trie.New[struct{}]()
count := 0
for _, rule := range rules {
actualDomain, _ := idna.ToASCII(rule)
err := domainTrie.Insert(actualDomain, struct{}{})
if err != nil {
log.Warnln("invalid domain:[%s]", rule)
} else {
count++
}
}
domainTrie.Optimize()
domainTrie := trie.NewDomainTrieSet(rules)
d.domainRules = domainTrie
d.count = count
d.count = len(rules)
}
func NewDomainStrategy() *domainStrategy {