diff --git a/component/geodata/utils.go b/component/geodata/utils.go index 3a48dc86..e3cfea43 100644 --- a/component/geodata/utils.go +++ b/component/geodata/utils.go @@ -2,9 +2,11 @@ package geodata import ( "github.com/Dreamacro/clash/component/geodata/router" + + "golang.org/x/exp/maps" ) -func LoadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) { +func loadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) { geoLoaderName := "standard" geoLoader, err := GetGeoDataLoader(geoLoaderName) if err != nil { @@ -28,3 +30,33 @@ func LoadGeoSiteMatcher(countryCode string) (*router.DomainMatcher, int, error) return matcher, len(domains), nil } + +var ruleProviders = make(map[string]*router.DomainMatcher) + +// HasProvider has geo site provider by county code +func HasProvider(countyCode string) (ok bool) { + _, ok = ruleProviders[countyCode] + return ok +} + +// GetProvidersList get geo site providers +func GetProvidersList(countyCode string) []*router.DomainMatcher { + return maps.Values(ruleProviders) +} + +// GetProviderByCode get geo site provider by county code +func GetProviderByCode(countyCode string) (matcher *router.DomainMatcher, ok bool) { + matcher, ok = ruleProviders[countyCode] + return +} + +func LoadProviderByCode(countyCode string) (matcher *router.DomainMatcher, count int, err error) { + var ok bool + matcher, ok = ruleProviders[countyCode] + if !ok { + if matcher, count, err = loadGeoSiteMatcher(countyCode); err == nil { + ruleProviders[countyCode] = matcher + } + } + return +} diff --git a/config/config.go b/config/config.go index 9d900452..e70315d2 100644 --- a/config/config.go +++ b/config/config.go @@ -19,6 +19,7 @@ import ( "github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/geodata" "github.com/Dreamacro/clash/component/geodata/router" + _ "github.com/Dreamacro/clash/component/geodata/standard" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" providerTypes "github.com/Dreamacro/clash/constant/provider" @@ -316,7 +317,7 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { } config.Hosts = hosts - dnsCfg, err := parseDNS(rawCfg, hosts, rules) + dnsCfg, err := parseDNS(rawCfg, hosts) if err != nil { return nil, err } @@ -674,37 +675,27 @@ func parseFallbackIPCIDR(ips []string) ([]*netip.Prefix, error) { return ipNets, nil } -func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainMatcher, error) { +func parseFallbackGeoSite(countries []string) ([]*router.DomainMatcher, error) { var sites []*router.DomainMatcher - for _, country := range countries { - found := false - for _, rule := range rules { - if rule.RuleType() == C.GEOSITE { - if strings.EqualFold(country, rule.Payload()) { - found = true - sites = append(sites, rule.(C.RuleGeoSite).GetDomainMatcher()) - log.Infoln("Start initial GeoSite dns fallback filter from rule `%s`", country) - } - } + matcher, recordsCount, err := geodata.LoadProviderByCode(country) + if err != nil { + return nil, err } - if !found { - matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country) - if err != nil { - return nil, err - } + sites = append(sites, matcher) - sites = append(sites, matcher) - - log.Infoln("Start initial GeoSite dns fallback filter `%s`, records: %d", country, recordsCount) + cont := fmt.Sprintf("%d", recordsCount) + if recordsCount == 0 { + cont = "from cache" } + log.Infoln("Start initial GeoSite dns fallback filter `%s`, records: %s", country, cont) } runtime.GC() return sites, nil } -func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.Rule) (*DNS, error) { +func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr]) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty") @@ -798,7 +789,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.R dnsCfg.FallbackFilter.IPCIDR = fallbackip } dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain - fallbackGeoSite, err := parseFallbackGeoSite(cfg.FallbackFilter.GeoSite, rules) + fallbackGeoSite, err := parseFallbackGeoSite(cfg.FallbackFilter.GeoSite) if err != nil { return nil, fmt.Errorf("load GeoSite dns fallback filter error, %w", err) } diff --git a/rule/geosite.go b/rule/geosite.go index b765d643..22f39887 100644 --- a/rule/geosite.go +++ b/rule/geosite.go @@ -5,7 +5,6 @@ import ( "github.com/Dreamacro/clash/component/geodata" "github.com/Dreamacro/clash/component/geodata/router" - _ "github.com/Dreamacro/clash/component/geodata/standard" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" ) @@ -47,12 +46,16 @@ func (gs *GEOSITE) GetDomainMatcher() *router.DomainMatcher { } func NewGEOSITE(country string, adapter string) (*GEOSITE, error) { - matcher, recordsCount, err := geodata.LoadGeoSiteMatcher(country) + matcher, recordsCount, err := geodata.LoadProviderByCode(country) if err != nil { return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) } - log.Infoln("Start initial GeoSite rule %s => %s, records: %d", country, adapter, recordsCount) + cont := fmt.Sprintf("%d", recordsCount) + if recordsCount == 0 { + cont = "from cache" + } + log.Infoln("Start initial GeoSite rule %s => %s, records: %s", country, adapter, cont) geoSite := &GEOSITE{ Base: &Base{},