Compare commits

..

3 Commits

Author SHA1 Message Date
Fangliding
0fc258edd3 Revert "Geofiles: Implement mmap in filesystem to reduce ram usage (#5480)"
This reverts commit 5d94a62a83.
2026-01-17 20:14:01 +08:00
Fangliding
1fc5b7aef5 Revert geofile change
This reverts commit 36425d2a6e.

This reverts commit 961c352127.

This reverts commit c715154309.
2026-01-17 20:10:33 +08:00
风扇滑翔翼
09f619d67c TLS client: Add pin_test.go for leaf and CA (#5553)
https://github.com/XTLS/Xray-core/pull/5532#issuecomment-3760231005
2026-01-17 09:42:06 +00:00
15 changed files with 305 additions and 557 deletions

View File

@@ -12,15 +12,12 @@ import (
"sync"
"time"
router "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/features/dns"
"google.golang.org/protobuf/proto"
)
// DNS is a DNS rely server.
@@ -100,25 +97,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
}
for _, ns := range config.NameServer {
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
err := parseDomains(ns)
if err != nil {
return nil, errors.New("failed to parse dns domain rules: ").Base(err)
}
expectedGeoip, err := router.GetGeoIPList(ns.ExpectedGeoip)
if err != nil {
return nil, errors.New("failed to parse dns expectIPs rules: ").Base(err)
}
ns.ExpectedGeoip = expectedGeoip
unexpectedGeoip, err := router.GetGeoIPList(ns.UnexpectedGeoip)
if err != nil {
return nil, errors.New("failed to parse dns unexpectedGeoip rules: ").Base(err)
}
ns.UnexpectedGeoip = unexpectedGeoip
}
domainRuleCount += len(ns.PrioritizedDomain)
}
@@ -602,76 +580,3 @@ func detectGUIPlatform() bool {
}
return false
}
func parseDomains(ns *NameServer) error {
pureDomains := []*router.Domain{}
// convert to pure domain
for _, pd := range ns.PrioritizedDomain {
pureDomains = append(pureDomains, &router.Domain{
Type: router.Domain_Type(pd.Type),
Value: pd.Domain,
})
}
domainList := []*router.Domain{}
for _, domain := range pureDomains {
val := strings.Split(domain.Value, "_")
if len(val) >= 2 {
fileName := val[0]
code := val[1]
bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
var geosite router.GeoSite
if err := proto.Unmarshal(bs, &geosite); err != nil {
return errors.New("failed Unmarshal :").Base(err)
}
// parse attr
if len(val) == 3 {
siteWithAttr := strings.Split(val[2], ",")
attrs := router.ParseAttrs(siteWithAttr)
if !attrs.IsEmpty() {
filteredDomains := make([]*router.Domain, 0, len(pureDomains))
for _, domain := range geosite.Domain {
if attrs.Match(domain) {
filteredDomains = append(filteredDomains, domain)
}
}
geosite.Domain = filteredDomains
}
}
domainList = append(domainList, geosite.Domain...)
// update ns.OriginalRules Size
ruleTag := strings.Join(val, ":")
for i, oRule := range ns.OriginalRules {
if oRule.Rule == strings.ToLower(ruleTag) {
ns.OriginalRules[i].Size = uint32(len(geosite.Domain))
}
}
} else {
domainList = append(domainList, domain)
}
}
// convert back to NameServer_PriorityDomain
ns.PrioritizedDomain = []*NameServer_PriorityDomain{}
for _, pd := range domainList {
ns.PrioritizedDomain = append(ns.PrioritizedDomain, &NameServer_PriorityDomain{
Type: ToDomainMatchingType(pd.Type),
Domain: pd.Value,
})
}
return nil
}

View File

@@ -541,7 +541,7 @@ func TestIPMatch(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{
// local
CountryCode: "local",
Cidr: []*router.CIDR{
{
// inner ip, will not match
@@ -565,7 +565,7 @@ func TestIPMatch(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{
// test
CountryCode: "test",
Cidr: []*router.CIDR{
{
Ip: []byte{8, 8, 8, 8},
@@ -574,7 +574,7 @@ func TestIPMatch(t *testing.T) {
},
},
{
// test
CountryCode: "test",
Cidr: []*router.CIDR{
{
Ip: []byte{8, 8, 8, 4},
@@ -669,7 +669,7 @@ func TestLocalDomain(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{ // Will match localhost, localhost-a and localhost-b,
// local
CountryCode: "local",
Cidr: []*router.CIDR{
{Ip: []byte{127, 0, 0, 2}, Prefix: 32},
{Ip: []byte{127, 0, 0, 3}, Prefix: 32},

View File

@@ -297,18 +297,3 @@ func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption)
return ipOption
}
}
func ToDomainMatchingType(t router.Domain_Type) DomainMatchingType {
switch t {
case router.Domain_Domain:
return DomainMatchingType_Subdomain
case router.Domain_Full:
return DomainMatchingType_Full
case router.Domain_Plain:
return DomainMatchingType_Keyword
case router.Domain_Regex:
return DomainMatchingType_Regex
default:
panic("unknown domain type")
}
}

View File

@@ -309,48 +309,6 @@ func (m *AttributeMatcher) Apply(ctx routing.Context) bool {
return m.Match(attributes)
}
// Geo attribute
type GeoAttributeMatcher interface {
Match(*Domain) bool
}
type GeoBooleanMatcher string
func (m GeoBooleanMatcher) Match(domain *Domain) bool {
for _, attr := range domain.Attribute {
if attr.Key == string(m) {
return true
}
}
return false
}
type GeoAttributeList struct {
Matcher []GeoAttributeMatcher
}
func (al *GeoAttributeList) Match(domain *Domain) bool {
for _, matcher := range al.Matcher {
if !matcher.Match(domain) {
return false
}
}
return true
}
func (al *GeoAttributeList) IsEmpty() bool {
return len(al.Matcher) == 0
}
func ParseAttrs(attrs []string) *GeoAttributeList {
al := new(GeoAttributeList)
for _, attr := range attrs {
lc := strings.ToLower(attr)
al.Matcher = append(al.Matcher, GeoBooleanMatcher(lc))
}
return al
}
type ProcessNameMatcher struct {
ProcessNames []string
AbsPaths []string
@@ -439,4 +397,4 @@ func (m *ProcessNameMatcher) Apply(ctx routing.Context) bool {
}
}
return false
}
}

View File

@@ -1,17 +1,40 @@
package router_test
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/infra/conf"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/platform/filesystem"
"google.golang.org/protobuf/proto"
)
func getAssetPath(file string) (string, error) {
path := platform.GetAssetLocation(file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
path := filepath.Join("..", "..", "resources", file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
func TestGeoIPMatcher(t *testing.T) {
cidrList := []*router.CIDR{
{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -159,11 +182,12 @@ func TestGeoIPReverseMatcher(t *testing.T) {
}
func TestGeoIPMatcher4CN(t *testing.T) {
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("CN")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
if matcher.Match([]byte{8, 8, 8, 8}) {
@@ -172,11 +196,12 @@ func TestGeoIPMatcher4CN(t *testing.T) {
}
func TestGeoIPMatcher6US(t *testing.T) {
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("US")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
@@ -184,34 +209,37 @@ func TestGeoIPMatcher6US(t *testing.T) {
}
}
func loadGeoIP(geo string) (*router.GeoIP, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
geoip, err := conf.ToCidrList([]string{geo})
func loadGeoIP(country string) ([]*router.CIDR, error) {
path, err := getAssetPath("geoip.dat")
if err != nil {
return nil, err
}
geoipBytes, err := filesystem.ReadFile(path)
if err != nil {
return nil, err
}
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
geoip, err = router.GetGeoIPList(geoip)
if err != nil {
return nil, err
var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
}
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
}
}
if len(geoip) == 0 {
panic("country not found: " + geo)
}
return geoip[0], nil
panic("country not found: " + country)
}
func BenchmarkGeoIPMatcher4CN(b *testing.B) {
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("CN")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
b.ResetTimer()
@@ -222,11 +250,12 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
}
func BenchmarkGeoIPMatcher6US(b *testing.B) {
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("US")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
b.ResetTimer()

View File

@@ -1,22 +1,20 @@
package router_test
import (
"os"
"path/filepath"
"runtime"
"strconv"
"testing"
"github.com/xtls/xray-core/app/router"
. "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/features/routing"
routing_session "github.com/xtls/xray-core/features/routing/session"
"github.com/xtls/xray-core/infra/conf"
"google.golang.org/protobuf/proto"
)
func withBackground() routing.Context {
@@ -302,25 +300,32 @@ func TestRoutingRule(t *testing.T) {
}
}
func loadGeoSiteDomains(geo string) ([]*Domain, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
domains, err := conf.ParseDomainRule(geo)
func loadGeoSite(country string) ([]*Domain, error) {
path, err := getAssetPath("geosite.dat")
if err != nil {
return nil, err
}
geositeBytes, err := filesystem.ReadFile(path)
if err != nil {
return nil, err
}
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
domains, err = router.GetDomainList(domains)
if err != nil {
return nil, err
var geositeList GeoSiteList
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
return nil, err
}
for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
}
}
return domains, nil
return nil, errors.New("country not found: " + country)
}
func TestChinaSites(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:cn")
domains, err := loadGeoSite("CN")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
@@ -361,50 +366,8 @@ func TestChinaSites(t *testing.T) {
}
}
func TestChinaSitesWithAttrs(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:google@cn")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
common.Must(err)
type TestCase struct {
Domain string
Output bool
}
testCases := []TestCase{
{
Domain: "google.cn",
Output: true,
},
{
Domain: "recaptcha.net",
Output: true,
},
{
Domain: "164.com",
Output: false,
},
{
Domain: "164.com",
Output: false,
},
}
for i := 0; i < 1024; i++ {
testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
}
for _, testCase := range testCases {
r := acMatcher.ApplyDomain(testCase.Domain)
if r != testCase.Output {
t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
}
}
}
func BenchmarkMphDomainMatcher(b *testing.B) {
domains, err := loadGeoSiteDomains("geosite:cn")
domains, err := loadGeoSite("CN")
common.Must(err)
matcher, err := NewMphMatcherGroup(domains)
@@ -449,11 +412,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
var geoips []*GeoIP
{
ips, err := loadGeoIP("geoip:cn")
ips, err := loadGeoIP("CN")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CN",
Cidr: ips.Cidr,
Cidr: ips,
})
}
@@ -462,25 +425,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "JP",
Cidr: ips.Cidr,
Cidr: ips,
})
}
{
ips, err := loadGeoIP("geoip:ca")
ips, err := loadGeoIP("CA")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CA",
Cidr: ips.Cidr,
Cidr: ips,
})
}
{
ips, err := loadGeoIP("geoip:us")
ips, err := loadGeoIP("US")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "US",
Cidr: ips.Cidr,
Cidr: ips,
})
}

View File

@@ -3,14 +3,11 @@ package router
import (
"context"
"regexp"
"runtime"
"strings"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/features/routing"
"google.golang.org/protobuf/proto"
)
type Rule struct {
@@ -76,15 +73,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
if len(rr.Geoip) > 0 {
geoip := rr.Geoip
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
geoip, err = GetGeoIPList(rr.Geoip)
if err != nil {
return nil, errors.New("failed to build geoip from mmap").Base(err)
}
}
cond, err := NewIPMatcher(geoip, MatcherAsType_Target)
cond, err := NewIPMatcher(rr.Geoip, MatcherAsType_Target)
if err != nil {
return nil, err
}
@@ -109,20 +98,11 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
if len(rr.Domain) > 0 {
domains := rr.Domain
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
domains, err = GetDomainList(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domains from mmap").Base(err)
}
}
matcher, err := NewMphMatcherGroup(domains)
matcher, err := NewMphMatcherGroup(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
}
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
conds.Add(matcher)
}
@@ -183,80 +163,3 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch
return nil, errors.New("unrecognized balancer type")
}
}
func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
geoipList := []*GeoIP{}
for _, ip := range ips {
if ip.CountryCode != "" {
val := strings.Split(ip.CountryCode, "_")
fileName := "geoip.dat"
if len(val) == 2 {
fileName = strings.ToLower(val[0])
}
bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return nil, errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(ip.CountryCode))
var geoip GeoIP
if err := proto.Unmarshal(bs, &geoip); err != nil {
return nil, errors.New("failed Unmarshal :").Base(err)
}
geoipList = append(geoipList, &geoip)
} else {
geoipList = append(geoipList, ip)
}
}
return geoipList, nil
}
func GetDomainList(domains []*Domain) ([]*Domain, error) {
domainList := []*Domain{}
for _, domain := range domains {
val := strings.Split(domain.Value, "_")
if len(val) >= 2 {
fileName := val[0]
code := val[1]
bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return nil, errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
var geosite GeoSite
if err := proto.Unmarshal(bs, &geosite); err != nil {
return nil, errors.New("failed Unmarshal :").Base(err)
}
// parse attr
if len(val) == 3 {
siteWithAttr := strings.Split(val[2], ",")
attrs := ParseAttrs(siteWithAttr)
if !attrs.IsEmpty() {
filteredDomains := make([]*Domain, 0, len(domains))
for _, domain := range geosite.Domain {
if attrs.Match(domain) {
filteredDomains = append(filteredDomains, domain)
}
}
geosite.Domain = filteredDomains
}
}
domainList = append(domainList, geosite.Domain...)
} else {
domainList = append(domainList, domain)
}
}
return domainList, nil
}

View File

@@ -1,52 +0,0 @@
package filesystem
func DecodeVarint(buf []byte) (x uint64, n int) {
for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) {
return 0, 0
}
b := uint64(buf[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0
}
func Find(data, code []byte) []byte {
codeL := len(code)
if codeL == 0 {
return nil
}
for {
dataL := len(data)
if dataL < 2 {
return nil
}
x, y := DecodeVarint(data[1:])
if x == 0 && y == 0 {
return nil
}
headL, bodyL := 1+y, int(x)
dataL -= headL
if dataL < bodyL {
return nil
}
data = data[headL:]
if int(data[1]) == codeL {
for i := 0; i < codeL && data[2+i] == code[i]; i++ {
if i+1 == codeL {
return data[:bodyL]
}
}
}
if dataL == bodyL {
return nil
}
data = data[bodyL:]
}
}

View File

@@ -1,12 +1,9 @@
//go:build !windows && !wasm
package filesystem
import (
"io"
"os"
"path/filepath"
"syscall"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/platform"
@@ -19,29 +16,6 @@ var NewFileReader FileReaderFunc = func(path string) (io.ReadCloser, error) {
}
func ReadFile(path string) ([]byte, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return nil, err
}
size := stat.Size()
if size == 0 {
return []byte{}, nil
}
// use mmap to save RAM
bs, err := syscall.Mmap(int(file.Fd()), 0, int(size), syscall.PROT_READ, syscall.MAP_SHARED)
if err == nil {
return bs, nil
}
// fallback
reader, err := NewFileReader(path)
if err != nil {
return nil, err

View File

@@ -1,54 +0,0 @@
//go:build windows || wasm
package filesystem
import (
"io"
"os"
"path/filepath"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/platform"
)
type FileReaderFunc func(path string) (io.ReadCloser, error)
var NewFileReader FileReaderFunc = func(path string) (io.ReadCloser, error) {
return os.Open(path)
}
func ReadFile(path string) ([]byte, error) {
reader, err := NewFileReader(path)
if err != nil {
return nil, err
}
defer reader.Close()
return buf.ReadAllToBytes(reader)
}
func ReadAsset(file string) ([]byte, error) {
return ReadFile(platform.GetAssetLocation(file))
}
func ReadCert(file string) ([]byte, error) {
if filepath.IsAbs(file) {
return ReadFile(file)
}
return ReadFile(platform.GetCertLocation(file))
}
func CopyFile(dst string, src string) error {
bytes, err := ReadFile(src)
if err != nil {
return err
}
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = f.Write(bytes)
return err
}

View File

@@ -3,9 +3,7 @@
package platform
import (
"path/filepath"
)
import "path/filepath"
func LineSeparator() string {
return "\r\n"
@@ -14,7 +12,6 @@ func LineSeparator() string {
// GetAssetLocation searches for `file` in the env dir and the executable dir
func GetAssetLocation(file string) string {
assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)
return filepath.Join(assetPath, file)
}

View File

@@ -80,6 +80,21 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
return errors.New("failed to parse name server: ", string(data))
}
func toDomainMatchingType(t router.Domain_Type) dns.DomainMatchingType {
switch t {
case router.Domain_Domain:
return dns.DomainMatchingType_Subdomain
case router.Domain_Full:
return dns.DomainMatchingType_Full
case router.Domain_Plain:
return dns.DomainMatchingType_Keyword
case router.Domain_Regex:
return dns.DomainMatchingType_Regex
default:
panic("unknown domain type")
}
}
func (c *NameServerConfig) Build() (*dns.NameServer, error) {
if c.Address == nil {
return nil, errors.New("NameServer address is not specified.")
@@ -89,14 +104,14 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
var originalRules []*dns.NameServer_OriginalRule
for _, rule := range c.Domains {
parsedDomain, err := ParseDomainRule(rule)
parsedDomain, err := parseDomainRule(rule)
if err != nil {
return nil, errors.New("invalid domain rule: ", rule).Base(err)
}
for _, pd := range parsedDomain {
domains = append(domains, &dns.NameServer_PriorityDomain{
Type: dns.ToDomainMatchingType(pd.Type),
Type: toDomainMatchingType(pd.Type),
Domain: pd.Value,
})
}

View File

@@ -203,23 +203,17 @@ func loadFile(file string) ([]byte, error) {
func loadIP(file, code string) ([]*router.CIDR, error) {
index := file + ":" + code
if IPCache[index] == nil {
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = find(bs, []byte(code))
if bs == nil {
return nil, errors.New("code not found in ", file, ": ", code)
}
var geoip router.GeoIP
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
// dont pass code becuase we have country code in top level router.GeoIP
geoip = router.GeoIP{Cidr: []*router.CIDR{}}
} else {
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
if bs == nil {
return nil, errors.New("code not found in ", file, ": ", code)
}
if err := proto.Unmarshal(bs, &geoip); err != nil {
return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
}
if err := proto.Unmarshal(bs, &geoip); err != nil {
return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
}
defer runtime.GC() // or debug.FreeOSMemory()
return geoip.Cidr, nil // do not cache geoip
@@ -231,28 +225,18 @@ func loadIP(file, code string) ([]*router.CIDR, error) {
func loadSite(file, code string) ([]*router.Domain, error) {
index := file + ":" + code
if SiteCache[index] == nil {
var geosite router.GeoSite
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
// pass file:code so can build optimized matcher later
domain := router.Domain{Value: file + "_" + code}
geosite = router.GeoSite{Domain: []*router.Domain{&domain}}
} else {
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
if bs == nil {
return nil, errors.New("list not found in ", file, ": ", code)
}
if err := proto.Unmarshal(bs, &geosite); err != nil {
return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
}
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = find(bs, []byte(code))
if bs == nil {
return nil, errors.New("list not found in ", file, ": ", code)
}
var geosite router.GeoSite
if err := proto.Unmarshal(bs, &geosite); err != nil {
return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
}
defer runtime.GC() // or debug.FreeOSMemory()
return geosite.Domain, nil // do not cache geosite
SiteCache[index] = &geosite
@@ -260,13 +244,105 @@ func loadSite(file, code string) ([]*router.Domain, error) {
return SiteCache[index].Domain, nil
}
func DecodeVarint(buf []byte) (x uint64, n int) {
for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) {
return 0, 0
}
b := uint64(buf[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0
}
func find(data, code []byte) []byte {
codeL := len(code)
if codeL == 0 {
return nil
}
for {
dataL := len(data)
if dataL < 2 {
return nil
}
x, y := DecodeVarint(data[1:])
if x == 0 && y == 0 {
return nil
}
headL, bodyL := 1+y, int(x)
dataL -= headL
if dataL < bodyL {
return nil
}
data = data[headL:]
if int(data[1]) == codeL {
for i := 0; i < codeL && data[2+i] == code[i]; i++ {
if i+1 == codeL {
return data[:bodyL]
}
}
}
if dataL == bodyL {
return nil
}
data = data[bodyL:]
}
}
type AttributeMatcher interface {
Match(*router.Domain) bool
}
type BooleanMatcher string
func (m BooleanMatcher) Match(domain *router.Domain) bool {
for _, attr := range domain.Attribute {
if attr.Key == string(m) {
return true
}
}
return false
}
type AttributeList struct {
matcher []AttributeMatcher
}
func (al *AttributeList) Match(domain *router.Domain) bool {
for _, matcher := range al.matcher {
if !matcher.Match(domain) {
return false
}
}
return true
}
func (al *AttributeList) IsEmpty() bool {
return len(al.matcher) == 0
}
func parseAttrs(attrs []string) *AttributeList {
al := new(AttributeList)
for _, attr := range attrs {
lc := strings.ToLower(attr)
al.matcher = append(al.matcher, BooleanMatcher(lc))
}
return al
}
func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) {
parts := strings.Split(siteWithAttr, "@")
if len(parts) == 0 {
return nil, errors.New("empty site")
}
country := strings.ToUpper(parts[0])
attrs := router.ParseAttrs(parts[1:])
attrs := parseAttrs(parts[1:])
domains, err := loadSite(file, country)
if err != nil {
return nil, err
@@ -276,11 +352,6 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return domains, nil
}
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
domains[0].Value = domains[0].Value + "_" + strings.Join(parts[1:], ",")
return domains, nil
}
filteredDomains := make([]*router.Domain, 0, len(domains))
for _, domain := range domains {
if attrs.Match(domain) {
@@ -291,7 +362,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return filteredDomains, nil
}
func ParseDomainRule(domain string) ([]*router.Domain, error) {
func parseDomainRule(domain string) ([]*router.Domain, error) {
if strings.HasPrefix(domain, "geosite:") {
country := strings.ToUpper(domain[8:])
domains, err := loadGeositeWithAttr("geosite.dat", country)
@@ -489,7 +560,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domain != nil {
for _, domain := range *rawFieldRule.Domain {
rules, err := ParseDomainRule(domain)
rules, err := parseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}
@@ -499,7 +570,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domains != nil {
for _, domain := range *rawFieldRule.Domains {
rules, err := ParseDomainRule(domain)
rules, err := parseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}

View File

@@ -89,14 +89,13 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
}
}()
r, ok := <-bind.readQueue
if !ok {
return 0, errors.New("channel closed")
r := &netReadInfo{
buff: bufs[0],
}
copy(bufs[0], r.buff[:r.bytes])
r.waiter.Add(1)
bind.readQueue <- r
r.waiter.Wait() // wait read goroutine done, or we will miss the result
sizes[0], eps[0] = r.bytes, r.endpoint
r.waiter.Done()
return 1, r.err
}
workers := bind.workers
@@ -134,29 +133,24 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
}
endpoint.conn = c
go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
defer func() {
_ = recover() // handle send on closed channel
}()
go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
for {
buff := make([]byte, 1700)
i, err := c.Read(buff)
v, ok := <-readQueue
if !ok {
return
}
i, err := c.Read(v.buff)
if i > 3 {
buff[1] = 0
buff[2] = 0
buff[3] = 0
v.buff[1] = 0
v.buff[2] = 0
v.buff[3] = 0
}
r := &netReadInfo{
buff: buff,
bytes: i,
endpoint: endpoint,
err: err,
}
r.waiter.Add(1)
readQueue <- r
r.waiter.Wait()
v.bytes = i
v.endpoint = endpoint
v.err = err
v.waiter.Done()
if err != nil {
endpoint.conn = nil
return

View File

@@ -1,12 +1,15 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"testing"
"github.com/stretchr/testify/assert"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/protocol/tls/cert"
)
func TestCalculateCertHash(t *testing.T) {
@@ -95,3 +98,60 @@ uI6HqHFD3iEct8fBkYfQiwH2e1eu9OwgujiWHsutyK8VvzVB3/YnhQ/TzciRjPqz
assert.Equal(t, fingerprint, hash)
})
}
func TestVerifyPeerLeafCert(t *testing.T) {
leafCert := cert.MustGenerate(nil, cert.DNSNames("example.com"))
leaf := common.Must2(x509.ParseCertificate(leafCert.Certificate))
caHash := GenerateCertHash(leafCert.Certificate)
r := &RandCarrier{
Config: &tls.Config{
ServerName: "example.com",
},
PinnedPeerCertSha256: [][]byte{caHash},
}
rawCerts := [][]byte{leaf.Raw}
err := r.verifyPeerCert(rawCerts, nil)
if err != nil {
t.Fatal("expected to verify leaf cert signed by pinned CA, but got error:", err)
}
// make the pinned hash incorrect
r.PinnedPeerCertSha256[0][0] += 1
err = r.verifyPeerCert(rawCerts, nil)
if err == nil {
t.Fatal("expected to fail verifying leaf cert with incorrect pinned CA hash, but got no error")
}
}
func TestVerifyPeerCACert(t *testing.T) {
caCert := cert.MustGenerate(nil, cert.Authority(true), cert.KeyUsage(x509.KeyUsageCertSign))
ca := common.Must2(x509.ParseCertificate(caCert.Certificate))
leafCert := cert.MustGenerate(caCert, cert.DNSNames("example.com"))
leaf := common.Must2(x509.ParseCertificate(leafCert.Certificate))
caHash := GenerateCertHash(ca)
r := &RandCarrier{
Config: &tls.Config{
ServerName: "example.com",
},
PinnedPeerCertSha256: [][]byte{caHash},
}
rawCerts := [][]byte{leaf.Raw, ca.Raw}
err := r.verifyPeerCert(rawCerts, nil)
if err != nil {
t.Fatal("expected to verify leaf cert signed by pinned CA, but got error:", err)
}
// make the pinned hash incorrect
r.PinnedPeerCertSha256[0][0] += 1
err = r.verifyPeerCert(rawCerts, nil)
if err == nil {
t.Fatal("expected to fail verifying leaf cert with incorrect pinned CA hash, but got no error")
}
}