Tests: Improve geosite & geoip tests (#5502)

https://github.com/XTLS/Xray-core/pull/5488#issuecomment-3711843548
This commit is contained in:
Hossin Asaadi
2026-01-08 11:30:49 +03:30
committed by GitHub
parent 39ba1f7952
commit 36425d2a6e
6 changed files with 103 additions and 92 deletions

View File

@@ -1,40 +1,17 @@
package router_test
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/platform/filesystem"
"google.golang.org/protobuf/proto"
"github.com/xtls/xray-core/infra/conf"
)
func getAssetPath(file string) (string, error) {
path := platform.GetAssetLocation(file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
path := filepath.Join("..", "..", "resources", file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
func TestGeoIPMatcher(t *testing.T) {
cidrList := []*router.CIDR{
{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -182,12 +159,11 @@ func TestGeoIPReverseMatcher(t *testing.T) {
}
func TestGeoIPMatcher4CN(t *testing.T) {
ips, err := loadGeoIP("CN")
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)
if matcher.Match([]byte{8, 8, 8, 8}) {
@@ -196,12 +172,11 @@ func TestGeoIPMatcher4CN(t *testing.T) {
}
func TestGeoIPMatcher6US(t *testing.T) {
ips, err := loadGeoIP("US")
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)
if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
@@ -209,37 +184,34 @@ func TestGeoIPMatcher6US(t *testing.T) {
}
}
func loadGeoIP(country string) ([]*router.CIDR, error) {
path, err := getAssetPath("geoip.dat")
if err != nil {
return nil, err
}
geoipBytes, err := filesystem.ReadFile(path)
func loadGeoIP(geo string) (*router.GeoIP, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
geoip, err := conf.ToCidrList([]string{geo})
if err != nil {
return nil, err
}
var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
}
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
geoip, err = router.GetGeoIPList(geoip)
if err != nil {
return nil, err
}
}
panic("country not found: " + country)
if len(geoip) == 0 {
panic("country not found: " + geo)
}
return geoip[0], nil
}
func BenchmarkGeoIPMatcher4CN(b *testing.B) {
ips, err := loadGeoIP("CN")
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)
b.ResetTimer()
@@ -250,12 +222,11 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
}
func BenchmarkGeoIPMatcher6US(b *testing.B) {
ips, err := loadGeoIP("US")
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)
b.ResetTimer()

View File

@@ -1,20 +1,22 @@
package router_test
import (
"os"
"path/filepath"
"runtime"
"strconv"
"testing"
"github.com/xtls/xray-core/app/router"
. "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/features/routing"
routing_session "github.com/xtls/xray-core/features/routing/session"
"google.golang.org/protobuf/proto"
"github.com/xtls/xray-core/infra/conf"
)
func withBackground() routing.Context {
@@ -300,32 +302,25 @@ func TestRoutingRule(t *testing.T) {
}
}
func loadGeoSite(country string) ([]*Domain, error) {
path, err := getAssetPath("geosite.dat")
if err != nil {
return nil, err
}
geositeBytes, err := filesystem.ReadFile(path)
func loadGeoSiteDomains(geo string) ([]*Domain, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
domains, err := conf.ParseDomainRule(geo)
if err != nil {
return nil, err
}
var geositeList GeoSiteList
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
return nil, err
}
for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
domains, err = router.GetDomainList(domains)
if err != nil {
return nil, err
}
}
return nil, errors.New("country not found: " + country)
return domains, nil
}
func TestChinaSites(t *testing.T) {
domains, err := loadGeoSite("CN")
domains, err := loadGeoSiteDomains("geosite:cn")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
@@ -366,8 +361,50 @@ func TestChinaSites(t *testing.T) {
}
}
func TestChinaSitesWithAttrs(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:google@cn")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
common.Must(err)
type TestCase struct {
Domain string
Output bool
}
testCases := []TestCase{
{
Domain: "google.cn",
Output: true,
},
{
Domain: "recaptcha.net",
Output: true,
},
{
Domain: "164.com",
Output: false,
},
{
Domain: "164.com",
Output: false,
},
}
for i := 0; i < 1024; i++ {
testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
}
for _, testCase := range testCases {
r := acMatcher.ApplyDomain(testCase.Domain)
if r != testCase.Output {
t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
}
}
}
func BenchmarkMphDomainMatcher(b *testing.B) {
domains, err := loadGeoSite("CN")
domains, err := loadGeoSiteDomains("geosite:cn")
common.Must(err)
matcher, err := NewMphMatcherGroup(domains)
@@ -412,11 +449,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
var geoips []*GeoIP
{
ips, err := loadGeoIP("CN")
ips, err := loadGeoIP("geoip:cn")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CN",
Cidr: ips,
Cidr: ips.Cidr,
})
}
@@ -425,25 +462,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "JP",
Cidr: ips,
Cidr: ips.Cidr,
})
}
{
ips, err := loadGeoIP("CA")
ips, err := loadGeoIP("geoip:ca")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CA",
Cidr: ips,
Cidr: ips.Cidr,
})
}
{
ips, err := loadGeoIP("US")
ips, err := loadGeoIP("geoip:us")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "US",
Cidr: ips,
Cidr: ips.Cidr,
})
}

View File

@@ -112,7 +112,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
domains := rr.Domain
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
domains, err = getDomainList(rr.Domain)
domains, err = GetDomainList(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domains from mmap").Base(err)
}
@@ -122,7 +122,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
}
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
conds.Add(matcher)
}
@@ -214,7 +214,7 @@ func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
}
func getDomainList(domains []*Domain) ([]*Domain, error) {
func GetDomainList(domains []*Domain) ([]*Domain, error) {
domainList := []*Domain{}
for _, domain := range domains {
val := strings.Split(domain.Value, "_")

View File

@@ -3,7 +3,9 @@
package platform
import "path/filepath"
import (
"path/filepath"
)
func LineSeparator() string {
return "\r\n"
@@ -12,6 +14,7 @@ func LineSeparator() string {
// GetAssetLocation searches for `file` in the env dir and the executable dir
func GetAssetLocation(file string) string {
assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)
return filepath.Join(assetPath, file)
}

View File

@@ -89,7 +89,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
var originalRules []*dns.NameServer_OriginalRule
for _, rule := range c.Domains {
parsedDomain, err := parseDomainRule(rule)
parsedDomain, err := ParseDomainRule(rule)
if err != nil {
return nil, errors.New("invalid domain rule: ", rule).Base(err)
}

View File

@@ -291,7 +291,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return filteredDomains, nil
}
func parseDomainRule(domain string) ([]*router.Domain, error) {
func ParseDomainRule(domain string) ([]*router.Domain, error) {
if strings.HasPrefix(domain, "geosite:") {
country := strings.ToUpper(domain[8:])
domains, err := loadGeositeWithAttr("geosite.dat", country)
@@ -489,7 +489,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domain != nil {
for _, domain := range *rawFieldRule.Domain {
rules, err := parseDomainRule(domain)
rules, err := ParseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}
@@ -499,7 +499,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domains != nil {
for _, domain := range *rawFieldRule.Domains {
rules, err := parseDomainRule(domain)
rules, err := ParseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}