From 36425d2a6e53c755c96d25ef2323b7fffd1fab3f Mon Sep 17 00:00:00 2001 From: Hossin Asaadi Date: Thu, 8 Jan 2026 11:30:49 +0330 Subject: [PATCH] Tests: Improve geosite & geoip tests (#5502) https://github.com/XTLS/Xray-core/pull/5488#issuecomment-3711843548 --- app/router/condition_geoip_test.go | 83 +++++++++----------------- app/router/condition_test.go | 93 +++++++++++++++++++++--------- app/router/config.go | 6 +- common/platform/windows.go | 5 +- infra/conf/dns.go | 2 +- infra/conf/router.go | 6 +- 6 files changed, 103 insertions(+), 92 deletions(-) diff --git a/app/router/condition_geoip_test.go b/app/router/condition_geoip_test.go index b712db9e..de289a71 100644 --- a/app/router/condition_geoip_test.go +++ b/app/router/condition_geoip_test.go @@ -1,40 +1,17 @@ package router_test import ( - "fmt" "os" "path/filepath" + "runtime" "testing" "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/platform/filesystem" - "google.golang.org/protobuf/proto" + "github.com/xtls/xray-core/infra/conf" ) -func getAssetPath(file string) (string, error) { - path := platform.GetAssetLocation(file) - _, err := os.Stat(path) - if os.IsNotExist(err) { - path := filepath.Join("..", "..", "resources", file) - _, err := os.Stat(path) - if os.IsNotExist(err) { - return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file) - } - if err != nil { - return "", fmt.Errorf("can't stat %s: %v", path, err) - } - return path, nil - } - if err != nil { - return "", fmt.Errorf("can't stat %s: %v", path, err) - } - - return path, nil -} - func TestGeoIPMatcher(t *testing.T) { cidrList := []*router.CIDR{ {Ip: []byte{0, 0, 0, 0}, Prefix: 8}, @@ -182,12 +159,11 @@ func TestGeoIPReverseMatcher(t *testing.T) { } func TestGeoIPMatcher4CN(t *testing.T) { - ips, err := loadGeoIP("CN") + geo := "geoip:cn" + geoip, err := loadGeoIP(geo) common.Must(err) - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) + matcher, err := router.BuildOptimizedGeoIPMatcher(geoip) common.Must(err) if matcher.Match([]byte{8, 8, 8, 8}) { @@ -196,12 +172,11 @@ func TestGeoIPMatcher4CN(t *testing.T) { } func TestGeoIPMatcher6US(t *testing.T) { - ips, err := loadGeoIP("US") + geo := "geoip:us" + geoip, err := loadGeoIP(geo) common.Must(err) - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) + matcher, err := router.BuildOptimizedGeoIPMatcher(geoip) common.Must(err) if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) { @@ -209,37 +184,34 @@ func TestGeoIPMatcher6US(t *testing.T) { } } -func loadGeoIP(country string) ([]*router.CIDR, error) { - path, err := getAssetPath("geoip.dat") - if err != nil { - return nil, err - } - geoipBytes, err := filesystem.ReadFile(path) +func loadGeoIP(geo string) (*router.GeoIP, error) { + os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources")) + + geoip, err := conf.ToCidrList([]string{geo}) if err != nil { return nil, err } - var geoipList router.GeoIPList - if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil { - return nil, err - } - - for _, geoip := range geoipList.Entry { - if geoip.CountryCode == country { - return geoip.Cidr, nil + if runtime.GOOS != "windows" && runtime.GOOS != "wasm" { + geoip, err = router.GetGeoIPList(geoip) + if err != nil { + return nil, err } } - panic("country not found: " + country) + if len(geoip) == 0 { + panic("country not found: " + geo) + } + + return geoip[0], nil } func BenchmarkGeoIPMatcher4CN(b *testing.B) { - ips, err := loadGeoIP("CN") + geo := "geoip:cn" + geoip, err := loadGeoIP(geo) common.Must(err) - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) + matcher, err := router.BuildOptimizedGeoIPMatcher(geoip) common.Must(err) b.ResetTimer() @@ -250,12 +222,11 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) { } func BenchmarkGeoIPMatcher6US(b *testing.B) { - ips, err := loadGeoIP("US") + geo := "geoip:us" + geoip, err := loadGeoIP(geo) common.Must(err) - matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ - Cidr: ips, - }) + matcher, err := router.BuildOptimizedGeoIPMatcher(geoip) common.Must(err) b.ResetTimer() diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 1272aef6..283e6725 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -1,20 +1,22 @@ package router_test import ( + "os" + "path/filepath" + "runtime" "strconv" "testing" + "github.com/xtls/xray-core/app/router" . "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" - "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/features/routing" routing_session "github.com/xtls/xray-core/features/routing/session" - "google.golang.org/protobuf/proto" + "github.com/xtls/xray-core/infra/conf" ) func withBackground() routing.Context { @@ -300,32 +302,25 @@ func TestRoutingRule(t *testing.T) { } } -func loadGeoSite(country string) ([]*Domain, error) { - path, err := getAssetPath("geosite.dat") - if err != nil { - return nil, err - } - geositeBytes, err := filesystem.ReadFile(path) +func loadGeoSiteDomains(geo string) ([]*Domain, error) { + os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources")) + + domains, err := conf.ParseDomainRule(geo) if err != nil { return nil, err } - var geositeList GeoSiteList - if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil { - return nil, err - } - - for _, site := range geositeList.Entry { - if site.CountryCode == country { - return site.Domain, nil + if runtime.GOOS != "windows" && runtime.GOOS != "wasm" { + domains, err = router.GetDomainList(domains) + if err != nil { + return nil, err } } - - return nil, errors.New("country not found: " + country) + return domains, nil } func TestChinaSites(t *testing.T) { - domains, err := loadGeoSite("CN") + domains, err := loadGeoSiteDomains("geosite:cn") common.Must(err) acMatcher, err := NewMphMatcherGroup(domains) @@ -366,8 +361,50 @@ func TestChinaSites(t *testing.T) { } } +func TestChinaSitesWithAttrs(t *testing.T) { + domains, err := loadGeoSiteDomains("geosite:google@cn") + common.Must(err) + + acMatcher, err := NewMphMatcherGroup(domains) + common.Must(err) + + type TestCase struct { + Domain string + Output bool + } + testCases := []TestCase{ + { + Domain: "google.cn", + Output: true, + }, + { + Domain: "recaptcha.net", + Output: true, + }, + { + Domain: "164.com", + Output: false, + }, + { + Domain: "164.com", + Output: false, + }, + } + + for i := 0; i < 1024; i++ { + testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false}) + } + + for _, testCase := range testCases { + r := acMatcher.ApplyDomain(testCase.Domain) + if r != testCase.Output { + t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r) + } + } +} + func BenchmarkMphDomainMatcher(b *testing.B) { - domains, err := loadGeoSite("CN") + domains, err := loadGeoSiteDomains("geosite:cn") common.Must(err) matcher, err := NewMphMatcherGroup(domains) @@ -412,11 +449,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) { var geoips []*GeoIP { - ips, err := loadGeoIP("CN") + ips, err := loadGeoIP("geoip:cn") common.Must(err) geoips = append(geoips, &GeoIP{ CountryCode: "CN", - Cidr: ips, + Cidr: ips.Cidr, }) } @@ -425,25 +462,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) { common.Must(err) geoips = append(geoips, &GeoIP{ CountryCode: "JP", - Cidr: ips, + Cidr: ips.Cidr, }) } { - ips, err := loadGeoIP("CA") + ips, err := loadGeoIP("geoip:ca") common.Must(err) geoips = append(geoips, &GeoIP{ CountryCode: "CA", - Cidr: ips, + Cidr: ips.Cidr, }) } { - ips, err := loadGeoIP("US") + ips, err := loadGeoIP("geoip:us") common.Must(err) geoips = append(geoips, &GeoIP{ CountryCode: "US", - Cidr: ips, + Cidr: ips.Cidr, }) } diff --git a/app/router/config.go b/app/router/config.go index 1566b560..5399589b 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -112,7 +112,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { domains := rr.Domain if runtime.GOOS != "windows" && runtime.GOOS != "wasm" { var err error - domains, err = getDomainList(rr.Domain) + domains, err = GetDomainList(rr.Domain) if err != nil { return nil, errors.New("failed to build domains from mmap").Base(err) } @@ -122,7 +122,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { if err != nil { return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) } - errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") + errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)") conds.Add(matcher) } @@ -214,7 +214,7 @@ func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) { } -func getDomainList(domains []*Domain) ([]*Domain, error) { +func GetDomainList(domains []*Domain) ([]*Domain, error) { domainList := []*Domain{} for _, domain := range domains { val := strings.Split(domain.Value, "_") diff --git a/common/platform/windows.go b/common/platform/windows.go index 684ddc9c..5bd15520 100644 --- a/common/platform/windows.go +++ b/common/platform/windows.go @@ -3,7 +3,9 @@ package platform -import "path/filepath" +import ( + "path/filepath" +) func LineSeparator() string { return "\r\n" @@ -12,6 +14,7 @@ func LineSeparator() string { // GetAssetLocation searches for `file` in the env dir and the executable dir func GetAssetLocation(file string) string { assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir) + return filepath.Join(assetPath, file) } diff --git a/infra/conf/dns.go b/infra/conf/dns.go index 6ec307c6..f6d56913 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -89,7 +89,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { var originalRules []*dns.NameServer_OriginalRule for _, rule := range c.Domains { - parsedDomain, err := parseDomainRule(rule) + parsedDomain, err := ParseDomainRule(rule) if err != nil { return nil, errors.New("invalid domain rule: ", rule).Base(err) } diff --git a/infra/conf/router.go b/infra/conf/router.go index 54320b3b..c9f5b43b 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -291,7 +291,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er return filteredDomains, nil } -func parseDomainRule(domain string) ([]*router.Domain, error) { +func ParseDomainRule(domain string) ([]*router.Domain, error) { if strings.HasPrefix(domain, "geosite:") { country := strings.ToUpper(domain[8:]) domains, err := loadGeositeWithAttr("geosite.dat", country) @@ -489,7 +489,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { if rawFieldRule.Domain != nil { for _, domain := range *rawFieldRule.Domain { - rules, err := parseDomainRule(domain) + rules, err := ParseDomainRule(domain) if err != nil { return nil, errors.New("failed to parse domain rule: ", domain).Base(err) } @@ -499,7 +499,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) { if rawFieldRule.Domains != nil { for _, domain := range *rawFieldRule.Domains { - rules, err := parseDomainRule(domain) + rules, err := ParseDomainRule(domain) if err != nil { return nil, errors.New("failed to parse domain rule: ", domain).Base(err) }