diff --git a/common/utils/strings.go b/common/utils/strings.go new file mode 100644 index 00000000..04ed7a2b --- /dev/null +++ b/common/utils/strings.go @@ -0,0 +1,8 @@ +package utils +func Reverse(s string) string { + a := []rune(s) + for i, j := 0, len(a)-1; i < j; i, j = i+1, j-1 { + a[i], a[j] = a[j], a[i] + } + return string(a) +} diff --git a/component/trie/sskv.go b/component/trie/sskv.go new file mode 100644 index 00000000..757822b3 --- /dev/null +++ b/component/trie/sskv.go @@ -0,0 +1,158 @@ +// Package succinct provides several succinct data types. +// Modify from https://github.com/openacid/succinct/sskv.go +package trie + +import ( + "sort" + "strings" + + "github.com/Dreamacro/clash/common/utils" + "github.com/openacid/low/bitmap" +) + +const ( + complexWildcardByte = byte('+') + wildcardByte = byte('*') + domainStepByte = byte('.') +) + +type Set struct { + leaves, labelBitmap []uint64 + labels []byte + ranks, selects []int32 +} + +// NewSet creates a new *Set struct, from a slice of sorted strings. +func NewDomainTrieSet(keys []string) *Set { + filter := make(map[string]struct{}, len(keys)) + reserveDomains := make([]string, 0, len(keys)) + for _, key := range keys { + items, ok := ValidAndSplitDomain(key) + if !ok { + continue + } + if items[0] == complexWildcard { + domain := strings.Join(items[1:], domainStep) + reserveDomain := utils.Reverse(domain) + filter[reserveDomain] = struct{}{} + reserveDomains = append(reserveDomains, reserveDomain) + } + + domain := strings.Join(items, domainStep) + reserveDomain := utils.Reverse(domain) + filter[reserveDomain] = struct{}{} + reserveDomains = append(reserveDomains, reserveDomain) + } + sort.Strings(reserveDomains) + keys=reserveDomains + ss := &Set{} + lIdx := 0 + + type qElt struct{ s, e, col int } + queue := []qElt{{0, len(keys), 0}} + for i := 0; i < len(queue); i++ { + elt := queue[i] + if elt.col == len(keys[elt.s]) { + elt.s++ + // a leaf node + setBit(&ss.leaves, i, 1) + } + + for j := elt.s; j < elt.e; { + + frm := j + + for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ { + } + queue = append(queue, qElt{frm, j, elt.col + 1}) + ss.labels = append(ss.labels, keys[frm][elt.col]) + setBit(&ss.labelBitmap, lIdx, 0) + lIdx++ + } + setBit(&ss.labelBitmap, lIdx, 1) + lIdx++ + } + + ss.init() + return ss +} + +// Has query for a key and return whether it presents in the Set. +func (ss *Set) Has(key string) bool { + key=utils.Reverse(key) + // no more labels in this node + // skip character matching + // go to next level + nodeId, bmIdx := 0, 0 + + for i := 0; i < len(key); i++ { + c := key[i] + for ; ; bmIdx++ { + if getBit(ss.labelBitmap, bmIdx) != 0 { + return false + } + // handle wildcard for domain + if ss.labels[bmIdx-nodeId] == complexWildcardByte { + return true + } else if ss.labels[bmIdx-nodeId] == wildcardByte { + j := i + for ; j < len(key); j++ { + if key[j] == domainStepByte { + i = j + goto END + } + } + return true + } else if ss.labels[bmIdx-nodeId] == c { + break + } + } + END: + nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) + bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1 + } + + if getBit(ss.leaves, nodeId) != 0 { + return true + } else { + return false + } +} + +func setBit(bm *[]uint64, i int, v int) { + for i>>6 >= len(*bm) { + *bm = append(*bm, 0) + } + (*bm)[i>>6] |= uint64(v) << uint(i&63) +} + +func getBit(bm []uint64, i int) uint64 { + return bm[i>>6] & (1 << uint(i&63)) +} + +// init builds pre-calculated cache to speed up rank() and select() +func (ss *Set) init() { + ss.selects, ss.ranks = bitmap.IndexSelect32R64(ss.labelBitmap) +} + +// countZeros counts the number of "0" in a bitmap before the i-th bit(excluding +// the i-th bit) on behalf of rank index. +// E.g.: +// +// countZeros("010010", 4) == 3 +// // 012345 +func countZeros(bm []uint64, ranks []int32, i int) int { + a, _ := bitmap.Rank64(bm, ranks, int32(i)) + return i - int(a) +} + +// selectIthOne returns the index of the i-th "1" in a bitmap, on behalf of rank +// and select indexes. +// E.g.: +// +// selectIthOne("010010", 1) == 4 +// // 012345 +func selectIthOne(bm []uint64, ranks, selects []int32, i int) int { + a, _ := bitmap.Select32R64(bm, selects, ranks, int32(i)) + return int(a) +} diff --git a/config/config.go b/config/config.go index c407aad5..f036097e 100644 --- a/config/config.go +++ b/config/config.go @@ -490,7 +490,7 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { } config.Hosts = hosts - dnsCfg, err := parseDNS(rawCfg, hosts, rules) + dnsCfg, err := parseDNS(rawCfg, hosts, rules, ruleProviders) if err != nil { return nil, err } @@ -983,7 +983,7 @@ func parsePureDNSServer(server string) string { } } } -func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][]dns.NameServer, error) { +func parseNameServerPolicy(nsPolicy map[string]any, ruleProviders map[string]providerTypes.RuleProvider, preferH3 bool) (map[string][]dns.NameServer, error) { policy := map[string][]dns.NameServer{} updatedPolicy := make(map[string]interface{}) re := regexp.MustCompile(`[a-zA-Z0-9\-]+\.[a-zA-Z]{2,}(\.[a-zA-Z]{2,})?`) @@ -998,6 +998,14 @@ func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][ newKey := "geosite:" + subkey updatedPolicy[newKey] = v } + } else if strings.Contains(k, "domain-set:") { + subkeys := strings.Split(k, ":") + subkeys = subkeys[1:] + subkeys = strings.Split(subkeys[0], ",") + for _, subkey := range subkeys { + newKey := "domain-set:" + subkey + updatedPolicy[newKey] = v + } } else if re.MatchString(k) { subkeys := strings.Split(k, ",") for _, subkey := range subkeys { @@ -1021,6 +1029,14 @@ func parseNameServerPolicy(nsPolicy map[string]any, preferH3 bool) (map[string][ if _, valid := trie.ValidAndSplitDomain(domain); !valid { return nil, fmt.Errorf("DNS ResoverRule invalid domain: %s", domain) } + if strings.HasPrefix(domain, "domain-set:") { + domainSetName := domain[11:] + if provider, ok := ruleProviders[domainSetName]; !ok { + return nil, fmt.Errorf("not found domain-set: %s", domainSetName) + } else if provider.Behavior() != providerTypes.Domain { + return nil, fmt.Errorf("rule provider type error, except domain,actual %s", provider.Behavior()) + } + } policy[domain] = nameservers } @@ -1077,7 +1093,7 @@ func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainM return sites, nil } -func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rules []C.Rule) (*DNS, error) { +func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rules []C.Rule, ruleProviders map[string]providerTypes.RuleProvider) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty") @@ -1104,7 +1120,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul return nil, err } - if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, cfg.PreferH3); err != nil { + if dnsCfg.NameServerPolicy, err = parseNameServerPolicy(cfg.NameServerPolicy, ruleProviders, cfg.PreferH3); err != nil { return nil, err } diff --git a/dns/resolver.go b/dns/resolver.go index c16aad40..b5a09fd0 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -16,6 +16,7 @@ import ( "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/constant/provider" "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" @@ -40,6 +41,11 @@ type geositePolicyRecord struct { inversedMatching bool } +type domainSetPolicyRecord struct { + domainSetProvider provider.RuleProvider + policy *Policy +} + type Resolver struct { ipv6 bool ipv6Timeout time.Duration @@ -51,6 +57,7 @@ type Resolver struct { group singleflight.Group lruCache *cache.LruCache[string, *D.Msg] policy *trie.DomainTrie[*Policy] + domainSetPolicy []domainSetPolicyRecord geositePolicy []geositePolicyRecord proxyServer []dnsClient } @@ -301,6 +308,12 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { return geositeRecord.policy.GetData() } } + metadata := &C.Metadata{Host: domain} + for _, domainSetRecord := range r.domainSetPolicy { + if ok := domainSetRecord.domainSetProvider.Match(metadata); ok { + return domainSetRecord.policy.GetData() + } + } return nil } @@ -422,16 +435,18 @@ type FallbackFilter struct { } type Config struct { - Main, Fallback []NameServer - Default []NameServer - ProxyServer []NameServer - IPv6 bool - IPv6Timeout uint - EnhancedMode C.DNSMode - FallbackFilter FallbackFilter - Pool *fakeip.Pool - Hosts *trie.DomainTrie[resolver.HostValue] - Policy map[string][]NameServer + Main, Fallback []NameServer + Default []NameServer + ProxyServer []NameServer + IPv6 bool + IPv6Timeout uint + EnhancedMode C.DNSMode + FallbackFilter FallbackFilter + Pool *fakeip.Pool + Hosts *trie.DomainTrie[resolver.HostValue] + Policy map[string][]NameServer + DomainSetPolicy map[provider.RuleProvider][]NameServer + GeositePolicy map[router.DomainMatcher][]NameServer } func NewResolver(config Config) *Resolver { @@ -483,6 +498,14 @@ func NewResolver(config Config) *Resolver { } r.policy.Optimize() } + if len(config.DomainSetPolicy) > 0 { + for p, n := range config.DomainSetPolicy { + r.domainSetPolicy = append(r.domainSetPolicy, domainSetPolicyRecord{ + domainSetProvider: p, + policy: NewPolicy(transform(n, defaultResolver)), + }) + } + } fallbackIPFilters := []fallbackIPFilter{} if config.FallbackFilter.GeoIP { diff --git a/go.mod b/go.mod index c8eb6e10..2667a532 100644 --- a/go.mod +++ b/go.mod @@ -67,6 +67,7 @@ require ( github.com/mdlayher/socket v0.4.0 // indirect github.com/metacubex/gvisor v0.0.0-20230323114922-412956fb6a03 // indirect github.com/onsi/ginkgo/v2 v2.2.0 // indirect + github.com/openacid/low v0.1.21 github.com/oschwald/maxminddb-golang v1.10.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect diff --git a/go.sum b/go.sum index f56aed7a..0099411f 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,7 @@ github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -34,6 +35,7 @@ github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1 github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= @@ -108,10 +110,16 @@ github.com/mroth/weightedrand/v2 v2.0.0/go.mod h1:f2faGsfOGOwc1p94wzHKKZyTpcJUW7 github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= +github.com/openacid/errors v0.8.1/go.mod h1:GUQEJJOJE3W9skHm8E8Y4phdl2LLEN8iD7c5gcGgdx0= +github.com/openacid/low v0.1.21 h1:Tr2GNu4N/+rGRYdOsEHOE89cxUIaDViZbVmKz29uKGo= +github.com/openacid/low v0.1.21/go.mod h1:q+MsKI6Pz2xsCkzV4BLj7NR5M4EX0sGz5AqotpZDVh0= +github.com/openacid/must v0.1.3/go.mod h1:luPiXCuJlEo3UUFQngVQokV0MPGryeYvtCbQPs3U1+I= +github.com/openacid/testkeys v0.1.6/go.mod h1:MfA7cACzBpbiwekivj8StqX0WIRmqlMsci1c37CA3Do= github.com/oschwald/geoip2-golang v1.8.0 h1:KfjYB8ojCEn/QLqsDU0AzrJ3R5Qa9vFlx3z6SLNcKTs= github.com/oschwald/geoip2-golang v1.8.0/go.mod h1:R7bRvYjOeaoenAp9sKRS8GX5bJWcZ0laWO5+DauEktw= github.com/oschwald/maxminddb-golang v1.10.0 h1:Xp1u0ZhqkSuopaKmk1WwHtjF0H9Hd9181uj2MQ5Vndg= github.com/oschwald/maxminddb-golang v1.10.0/go.mod h1:Y2ELenReaLAZ0b400URyGwvYxHV1dLIxBuyOsyYjHK0= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= @@ -148,6 +156,7 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 1b2ec572..19612fee 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -5,6 +5,7 @@ import ( "net/netip" "os" "runtime" + "strings" "sync" "github.com/Dreamacro/clash/adapter" @@ -91,7 +92,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateSniffer(cfg.Sniffer) updateHosts(cfg.Hosts) updateGeneral(cfg.General) - updateDNS(cfg.DNS, cfg.General.IPv6) + updateDNS(cfg.DNS, cfg.RuleProviders, cfg.General.IPv6) updateListeners(cfg.General, cfg.Listeners, force) updateIPTables(cfg) updateTun(cfg.General) @@ -178,7 +179,7 @@ func updateExperimental(c *config.Config) { runtime.GC() } -func updateDNS(c *config.DNS, generalIPv6 bool) { +func updateDNS(c *config.DNS, ruleProvider map[string]provider.RuleProvider, generalIPv6 bool) { if !c.Enable { resolver.DefaultResolver = nil resolver.DefaultHostMapper = nil @@ -186,7 +187,25 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { dns.ReCreateServer("", nil, nil) return } - + policy := make(map[string][]dns.NameServer) + domainSetPolicies := make(map[provider.RuleProvider][]dns.NameServer) + for key, nameservers := range c.NameServerPolicy { + temp := strings.Split(key, ":") + if len(temp) == 2 { + prefix := temp[0] + key := temp[1] + switch strings.ToLower(prefix) { + case "domain-set": + if p, ok := ruleProvider[key]; ok { + domainSetPolicies[p] = nameservers + } + case "geosite": + // TODO: + } + } else { + policy[key] = nameservers + } + } cfg := dns.Config{ Main: c.NameServer, Fallback: c.Fallback, @@ -205,6 +224,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { Default: c.DefaultNameserver, Policy: c.NameServerPolicy, ProxyServer: c.ProxyServerNameserver, + DomainSetPolicy: domainSetPolicies, } r := dns.NewResolver(cfg) diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index add64e76..9f4ab0d8 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -3,13 +3,11 @@ package provider import ( "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" - "golang.org/x/net/idna" ) type domainStrategy struct { count int - domainRules *trie.DomainTrie[struct{}] + domainRules *trie.Set } func (d *domainStrategy) ShouldFindProcess() bool { @@ -17,7 +15,7 @@ func (d *domainStrategy) ShouldFindProcess() bool { } func (d *domainStrategy) Match(metadata *C.Metadata) bool { - return d.domainRules != nil && d.domainRules.Search(metadata.RuleHost()) != nil + return d.domainRules != nil && d.domainRules.Has(metadata.RuleHost()) } func (d *domainStrategy) Count() int { @@ -29,21 +27,9 @@ func (d *domainStrategy) ShouldResolveIP() bool { } func (d *domainStrategy) OnUpdate(rules []string) { - domainTrie := trie.New[struct{}]() - count := 0 - for _, rule := range rules { - actualDomain, _ := idna.ToASCII(rule) - err := domainTrie.Insert(actualDomain, struct{}{}) - if err != nil { - log.Warnln("invalid domain:[%s]", rule) - } else { - count++ - } - } - domainTrie.Optimize() - + domainTrie := trie.NewDomainTrieSet(rules) d.domainRules = domainTrie - d.count = count + d.count = len(rules) } func NewDomainStrategy() *domainStrategy {