Tests: Improve geosite & geoip tests (#5502)

https://github.com/XTLS/Xray-core/pull/5488#issuecomment-3711843548
This commit is contained in:
Hossin Asaadi
2026-01-08 11:30:49 +03:30
committed by GitHub
parent 39ba1f7952
commit 36425d2a6e
6 changed files with 103 additions and 92 deletions

View File

@@ -1,40 +1,17 @@
package router_test package router_test
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
"github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/infra/conf"
"github.com/xtls/xray-core/common/platform/filesystem"
"google.golang.org/protobuf/proto"
) )
func getAssetPath(file string) (string, error) {
path := platform.GetAssetLocation(file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
path := filepath.Join("..", "..", "resources", file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
func TestGeoIPMatcher(t *testing.T) { func TestGeoIPMatcher(t *testing.T) {
cidrList := []*router.CIDR{ cidrList := []*router.CIDR{
{Ip: []byte{0, 0, 0, 0}, Prefix: 8}, {Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -182,12 +159,11 @@ func TestGeoIPReverseMatcher(t *testing.T) {
} }
func TestGeoIPMatcher4CN(t *testing.T) { func TestGeoIPMatcher4CN(t *testing.T) {
ips, err := loadGeoIP("CN") geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
common.Must(err) common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
Cidr: ips,
})
common.Must(err) common.Must(err)
if matcher.Match([]byte{8, 8, 8, 8}) { if matcher.Match([]byte{8, 8, 8, 8}) {
@@ -196,12 +172,11 @@ func TestGeoIPMatcher4CN(t *testing.T) {
} }
func TestGeoIPMatcher6US(t *testing.T) { func TestGeoIPMatcher6US(t *testing.T) {
ips, err := loadGeoIP("US") geo := "geoip:us"
geoip, err := loadGeoIP(geo)
common.Must(err) common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
Cidr: ips,
})
common.Must(err) common.Must(err)
if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) { if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
@@ -209,37 +184,34 @@ func TestGeoIPMatcher6US(t *testing.T) {
} }
} }
func loadGeoIP(country string) ([]*router.CIDR, error) { func loadGeoIP(geo string) (*router.GeoIP, error) {
path, err := getAssetPath("geoip.dat") os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
if err != nil {
return nil, err geoip, err := conf.ToCidrList([]string{geo})
}
geoipBytes, err := filesystem.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var geoipList router.GeoIPList if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil { geoip, err = router.GetGeoIPList(geoip)
return nil, err if err != nil {
} return nil, err
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
} }
} }
panic("country not found: " + country) if len(geoip) == 0 {
panic("country not found: " + geo)
}
return geoip[0], nil
} }
func BenchmarkGeoIPMatcher4CN(b *testing.B) { func BenchmarkGeoIPMatcher4CN(b *testing.B) {
ips, err := loadGeoIP("CN") geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
common.Must(err) common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
Cidr: ips,
})
common.Must(err) common.Must(err)
b.ResetTimer() b.ResetTimer()
@@ -250,12 +222,11 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
} }
func BenchmarkGeoIPMatcher6US(b *testing.B) { func BenchmarkGeoIPMatcher6US(b *testing.B) {
ips, err := loadGeoIP("US") geo := "geoip:us"
geoip, err := loadGeoIP(geo)
common.Must(err) common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{ matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
Cidr: ips,
})
common.Must(err) common.Must(err)
b.ResetTimer() b.ResetTimer()

View File

@@ -1,20 +1,22 @@
package router_test package router_test
import ( import (
"os"
"path/filepath"
"runtime"
"strconv" "strconv"
"testing" "testing"
"github.com/xtls/xray-core/app/router"
. "github.com/xtls/xray-core/app/router" . "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
routing_session "github.com/xtls/xray-core/features/routing/session" routing_session "github.com/xtls/xray-core/features/routing/session"
"google.golang.org/protobuf/proto" "github.com/xtls/xray-core/infra/conf"
) )
func withBackground() routing.Context { func withBackground() routing.Context {
@@ -300,32 +302,25 @@ func TestRoutingRule(t *testing.T) {
} }
} }
func loadGeoSite(country string) ([]*Domain, error) { func loadGeoSiteDomains(geo string) ([]*Domain, error) {
path, err := getAssetPath("geosite.dat") os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
if err != nil {
return nil, err domains, err := conf.ParseDomainRule(geo)
}
geositeBytes, err := filesystem.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var geositeList GeoSiteList if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil { domains, err = router.GetDomainList(domains)
return nil, err if err != nil {
} return nil, err
for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
} }
} }
return domains, nil
return nil, errors.New("country not found: " + country)
} }
func TestChinaSites(t *testing.T) { func TestChinaSites(t *testing.T) {
domains, err := loadGeoSite("CN") domains, err := loadGeoSiteDomains("geosite:cn")
common.Must(err) common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains) acMatcher, err := NewMphMatcherGroup(domains)
@@ -366,8 +361,50 @@ func TestChinaSites(t *testing.T) {
} }
} }
func TestChinaSitesWithAttrs(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:google@cn")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
common.Must(err)
type TestCase struct {
Domain string
Output bool
}
testCases := []TestCase{
{
Domain: "google.cn",
Output: true,
},
{
Domain: "recaptcha.net",
Output: true,
},
{
Domain: "164.com",
Output: false,
},
{
Domain: "164.com",
Output: false,
},
}
for i := 0; i < 1024; i++ {
testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
}
for _, testCase := range testCases {
r := acMatcher.ApplyDomain(testCase.Domain)
if r != testCase.Output {
t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
}
}
}
func BenchmarkMphDomainMatcher(b *testing.B) { func BenchmarkMphDomainMatcher(b *testing.B) {
domains, err := loadGeoSite("CN") domains, err := loadGeoSiteDomains("geosite:cn")
common.Must(err) common.Must(err)
matcher, err := NewMphMatcherGroup(domains) matcher, err := NewMphMatcherGroup(domains)
@@ -412,11 +449,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
var geoips []*GeoIP var geoips []*GeoIP
{ {
ips, err := loadGeoIP("CN") ips, err := loadGeoIP("geoip:cn")
common.Must(err) common.Must(err)
geoips = append(geoips, &GeoIP{ geoips = append(geoips, &GeoIP{
CountryCode: "CN", CountryCode: "CN",
Cidr: ips, Cidr: ips.Cidr,
}) })
} }
@@ -425,25 +462,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
common.Must(err) common.Must(err)
geoips = append(geoips, &GeoIP{ geoips = append(geoips, &GeoIP{
CountryCode: "JP", CountryCode: "JP",
Cidr: ips, Cidr: ips.Cidr,
}) })
} }
{ {
ips, err := loadGeoIP("CA") ips, err := loadGeoIP("geoip:ca")
common.Must(err) common.Must(err)
geoips = append(geoips, &GeoIP{ geoips = append(geoips, &GeoIP{
CountryCode: "CA", CountryCode: "CA",
Cidr: ips, Cidr: ips.Cidr,
}) })
} }
{ {
ips, err := loadGeoIP("US") ips, err := loadGeoIP("geoip:us")
common.Must(err) common.Must(err)
geoips = append(geoips, &GeoIP{ geoips = append(geoips, &GeoIP{
CountryCode: "US", CountryCode: "US",
Cidr: ips, Cidr: ips.Cidr,
}) })
} }

View File

@@ -112,7 +112,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
domains := rr.Domain domains := rr.Domain
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" { if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error var err error
domains, err = getDomainList(rr.Domain) domains, err = GetDomainList(rr.Domain)
if err != nil { if err != nil {
return nil, errors.New("failed to build domains from mmap").Base(err) return nil, errors.New("failed to build domains from mmap").Base(err)
} }
@@ -122,7 +122,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
if err != nil { if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
} }
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
conds.Add(matcher) conds.Add(matcher)
} }
@@ -214,7 +214,7 @@ func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
} }
func getDomainList(domains []*Domain) ([]*Domain, error) { func GetDomainList(domains []*Domain) ([]*Domain, error) {
domainList := []*Domain{} domainList := []*Domain{}
for _, domain := range domains { for _, domain := range domains {
val := strings.Split(domain.Value, "_") val := strings.Split(domain.Value, "_")

View File

@@ -3,7 +3,9 @@
package platform package platform
import "path/filepath" import (
"path/filepath"
)
func LineSeparator() string { func LineSeparator() string {
return "\r\n" return "\r\n"
@@ -12,6 +14,7 @@ func LineSeparator() string {
// GetAssetLocation searches for `file` in the env dir and the executable dir // GetAssetLocation searches for `file` in the env dir and the executable dir
func GetAssetLocation(file string) string { func GetAssetLocation(file string) string {
assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir) assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)
return filepath.Join(assetPath, file) return filepath.Join(assetPath, file)
} }

View File

@@ -89,7 +89,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
var originalRules []*dns.NameServer_OriginalRule var originalRules []*dns.NameServer_OriginalRule
for _, rule := range c.Domains { for _, rule := range c.Domains {
parsedDomain, err := parseDomainRule(rule) parsedDomain, err := ParseDomainRule(rule)
if err != nil { if err != nil {
return nil, errors.New("invalid domain rule: ", rule).Base(err) return nil, errors.New("invalid domain rule: ", rule).Base(err)
} }

View File

@@ -291,7 +291,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return filteredDomains, nil return filteredDomains, nil
} }
func parseDomainRule(domain string) ([]*router.Domain, error) { func ParseDomainRule(domain string) ([]*router.Domain, error) {
if strings.HasPrefix(domain, "geosite:") { if strings.HasPrefix(domain, "geosite:") {
country := strings.ToUpper(domain[8:]) country := strings.ToUpper(domain[8:])
domains, err := loadGeositeWithAttr("geosite.dat", country) domains, err := loadGeositeWithAttr("geosite.dat", country)
@@ -489,7 +489,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domain != nil { if rawFieldRule.Domain != nil {
for _, domain := range *rawFieldRule.Domain { for _, domain := range *rawFieldRule.Domain {
rules, err := parseDomainRule(domain) rules, err := ParseDomainRule(domain)
if err != nil { if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err) return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
} }
@@ -499,7 +499,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domains != nil { if rawFieldRule.Domains != nil {
for _, domain := range *rawFieldRule.Domains { for _, domain := range *rawFieldRule.Domains {
rules, err := parseDomainRule(domain) rules, err := ParseDomainRule(domain)
if err != nil { if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err) return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
} }