Compare commits

...

9 Commits

Author SHA1 Message Date
Fangliding
0fc258edd3 Revert "Geofiles: Implement mmap in filesystem to reduce ram usage (#5480)"
This reverts commit 5d94a62a83.
2026-01-17 20:14:01 +08:00
Fangliding
1fc5b7aef5 Revert geofile change
This reverts commit 36425d2a6e.

This reverts commit 961c352127.

This reverts commit c715154309.
2026-01-17 20:10:33 +08:00
风扇滑翔翼
09f619d67c TLS client: Add pin_test.go for leaf and CA (#5553)
https://github.com/XTLS/Xray-core/pull/5532#issuecomment-3760231005
2026-01-17 09:42:06 +00:00
RPRX
760223ad70 TLS client: Skip TLS' built-in verification when using pinnedPeerCertSha256; Fixes
https://github.com/XTLS/Xray-core/pull/5532#issuecomment-3745598515
https://github.com/XTLS/Xray-core/pull/5532#issuecomment-3759930283
https://github.com/XTLS/Xray-core/pull/5532#issuecomment-3760057266
https://github.com/XTLS/Xray-core/pull/5532#issuecomment-3760540231
2026-01-16 15:23:39 +00:00
nobody
d75b33a3a3 Commands: "xray run -dump" supports reading JSON from STDIN (#5550)
Fixes https://github.com/XTLS/Xray-core/issues/5534
2026-01-16 10:55:43 +00:00
Happ-dev
7c418486c8 README.md: Add Happ RU to iOS & macOS Clients (#5551) 2026-01-16 10:44:44 +00:00
风扇滑翔翼
a384be0f84 SS2022 outbound: Fix UDP leak (#5544)
Fixes https://github.com/XTLS/Xray-core/issues/5541
2026-01-14 13:31:43 +00:00
LjhAUMEM
649e989fa2 Hysteria: Fix transport's "udphop without salamander" dialing issue; Require "version": 2 in outbound's settings as well (#5537)
Updated example: https://github.com/XTLS/Xray-core/pull/5508#issue-3795798712
2026-01-14 10:42:07 +00:00
dependabot[bot]
0443de7798 Bump github.com/refraction-networking/utls from 1.8.1 to 1.8.2 (#5535)
Bumps [github.com/refraction-networking/utls](https://github.com/refraction-networking/utls) from 1.8.1 to 1.8.2.
- [Release notes](https://github.com/refraction-networking/utls/releases)
- [Commits](https://github.com/refraction-networking/utls/compare/v1.8.1...v1.8.2)

---
updated-dependencies:
- dependency-name: github.com/refraction-networking/utls
  dependency-version: 1.8.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-14 08:24:12 +00:00
25 changed files with 388 additions and 564 deletions

View File

@@ -112,11 +112,11 @@
- [SimpleXray](https://github.com/lhear/SimpleXray)
- [AnyPortal](https://github.com/AnyPortal/AnyPortal)
- iOS & macOS arm64 & tvOS
- [Happ](https://apps.apple.com/app/happ-proxy-utility/id6504287215) ([tvOS](https://apps.apple.com/us/app/happ-proxy-utility-for-tv/id6748297274))
- [Happ](https://apps.apple.com/app/happ-proxy-utility/id6504287215) | [Happ RU](https://apps.apple.com/ru/app/happ-proxy-utility-plus/id6746188973) | [Happ tvOS](https://apps.apple.com/us/app/happ-proxy-utility-for-tv/id6748297274)
- [Streisand](https://apps.apple.com/app/streisand/id6450534064)
- [OneXray](https://github.com/OneXray/OneXray)
- macOS arm64 & x64
- [Happ](https://apps.apple.com/app/happ-proxy-utility/id6504287215)
- [Happ](https://apps.apple.com/app/happ-proxy-utility/id6504287215) | [Happ RU](https://apps.apple.com/ru/app/happ-proxy-utility-plus/id6746188973)
- [V2rayU](https://github.com/yanue/V2rayU)
- [V2RayXS](https://github.com/tzmax/V2RayXS)
- [Furious](https://github.com/LorenEteval/Furious)

View File

@@ -12,15 +12,12 @@ import (
"sync"
"time"
router "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/features/dns"
"google.golang.org/protobuf/proto"
)
// DNS is a DNS rely server.
@@ -100,25 +97,6 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
}
for _, ns := range config.NameServer {
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
err := parseDomains(ns)
if err != nil {
return nil, errors.New("failed to parse dns domain rules: ").Base(err)
}
expectedGeoip, err := router.GetGeoIPList(ns.ExpectedGeoip)
if err != nil {
return nil, errors.New("failed to parse dns expectIPs rules: ").Base(err)
}
ns.ExpectedGeoip = expectedGeoip
unexpectedGeoip, err := router.GetGeoIPList(ns.UnexpectedGeoip)
if err != nil {
return nil, errors.New("failed to parse dns unexpectedGeoip rules: ").Base(err)
}
ns.UnexpectedGeoip = unexpectedGeoip
}
domainRuleCount += len(ns.PrioritizedDomain)
}
@@ -602,76 +580,3 @@ func detectGUIPlatform() bool {
}
return false
}
func parseDomains(ns *NameServer) error {
pureDomains := []*router.Domain{}
// convert to pure domain
for _, pd := range ns.PrioritizedDomain {
pureDomains = append(pureDomains, &router.Domain{
Type: router.Domain_Type(pd.Type),
Value: pd.Domain,
})
}
domainList := []*router.Domain{}
for _, domain := range pureDomains {
val := strings.Split(domain.Value, "_")
if len(val) >= 2 {
fileName := val[0]
code := val[1]
bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
var geosite router.GeoSite
if err := proto.Unmarshal(bs, &geosite); err != nil {
return errors.New("failed Unmarshal :").Base(err)
}
// parse attr
if len(val) == 3 {
siteWithAttr := strings.Split(val[2], ",")
attrs := router.ParseAttrs(siteWithAttr)
if !attrs.IsEmpty() {
filteredDomains := make([]*router.Domain, 0, len(pureDomains))
for _, domain := range geosite.Domain {
if attrs.Match(domain) {
filteredDomains = append(filteredDomains, domain)
}
}
geosite.Domain = filteredDomains
}
}
domainList = append(domainList, geosite.Domain...)
// update ns.OriginalRules Size
ruleTag := strings.Join(val, ":")
for i, oRule := range ns.OriginalRules {
if oRule.Rule == strings.ToLower(ruleTag) {
ns.OriginalRules[i].Size = uint32(len(geosite.Domain))
}
}
} else {
domainList = append(domainList, domain)
}
}
// convert back to NameServer_PriorityDomain
ns.PrioritizedDomain = []*NameServer_PriorityDomain{}
for _, pd := range domainList {
ns.PrioritizedDomain = append(ns.PrioritizedDomain, &NameServer_PriorityDomain{
Type: ToDomainMatchingType(pd.Type),
Domain: pd.Value,
})
}
return nil
}

View File

@@ -541,7 +541,7 @@ func TestIPMatch(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{
// local
CountryCode: "local",
Cidr: []*router.CIDR{
{
// inner ip, will not match
@@ -565,7 +565,7 @@ func TestIPMatch(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{
// test
CountryCode: "test",
Cidr: []*router.CIDR{
{
Ip: []byte{8, 8, 8, 8},
@@ -574,7 +574,7 @@ func TestIPMatch(t *testing.T) {
},
},
{
// test
CountryCode: "test",
Cidr: []*router.CIDR{
{
Ip: []byte{8, 8, 8, 4},
@@ -669,7 +669,7 @@ func TestLocalDomain(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{ // Will match localhost, localhost-a and localhost-b,
// local
CountryCode: "local",
Cidr: []*router.CIDR{
{Ip: []byte{127, 0, 0, 2}, Prefix: 32},
{Ip: []byte{127, 0, 0, 3}, Prefix: 32},

View File

@@ -297,18 +297,3 @@ func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption)
return ipOption
}
}
func ToDomainMatchingType(t router.Domain_Type) DomainMatchingType {
switch t {
case router.Domain_Domain:
return DomainMatchingType_Subdomain
case router.Domain_Full:
return DomainMatchingType_Full
case router.Domain_Plain:
return DomainMatchingType_Keyword
case router.Domain_Regex:
return DomainMatchingType_Regex
default:
panic("unknown domain type")
}
}

View File

@@ -309,48 +309,6 @@ func (m *AttributeMatcher) Apply(ctx routing.Context) bool {
return m.Match(attributes)
}
// Geo attribute
type GeoAttributeMatcher interface {
Match(*Domain) bool
}
type GeoBooleanMatcher string
func (m GeoBooleanMatcher) Match(domain *Domain) bool {
for _, attr := range domain.Attribute {
if attr.Key == string(m) {
return true
}
}
return false
}
type GeoAttributeList struct {
Matcher []GeoAttributeMatcher
}
func (al *GeoAttributeList) Match(domain *Domain) bool {
for _, matcher := range al.Matcher {
if !matcher.Match(domain) {
return false
}
}
return true
}
func (al *GeoAttributeList) IsEmpty() bool {
return len(al.Matcher) == 0
}
func ParseAttrs(attrs []string) *GeoAttributeList {
al := new(GeoAttributeList)
for _, attr := range attrs {
lc := strings.ToLower(attr)
al.Matcher = append(al.Matcher, GeoBooleanMatcher(lc))
}
return al
}
type ProcessNameMatcher struct {
ProcessNames []string
AbsPaths []string
@@ -439,4 +397,4 @@ func (m *ProcessNameMatcher) Apply(ctx routing.Context) bool {
}
}
return false
}
}

View File

@@ -1,17 +1,40 @@
package router_test
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/infra/conf"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/platform/filesystem"
"google.golang.org/protobuf/proto"
)
func getAssetPath(file string) (string, error) {
path := platform.GetAssetLocation(file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
path := filepath.Join("..", "..", "resources", file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
func TestGeoIPMatcher(t *testing.T) {
cidrList := []*router.CIDR{
{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -159,11 +182,12 @@ func TestGeoIPReverseMatcher(t *testing.T) {
}
func TestGeoIPMatcher4CN(t *testing.T) {
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("CN")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
if matcher.Match([]byte{8, 8, 8, 8}) {
@@ -172,11 +196,12 @@ func TestGeoIPMatcher4CN(t *testing.T) {
}
func TestGeoIPMatcher6US(t *testing.T) {
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("US")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
@@ -184,34 +209,37 @@ func TestGeoIPMatcher6US(t *testing.T) {
}
}
func loadGeoIP(geo string) (*router.GeoIP, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
geoip, err := conf.ToCidrList([]string{geo})
func loadGeoIP(country string) ([]*router.CIDR, error) {
path, err := getAssetPath("geoip.dat")
if err != nil {
return nil, err
}
geoipBytes, err := filesystem.ReadFile(path)
if err != nil {
return nil, err
}
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
geoip, err = router.GetGeoIPList(geoip)
if err != nil {
return nil, err
var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
}
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
}
}
if len(geoip) == 0 {
panic("country not found: " + geo)
}
return geoip[0], nil
panic("country not found: " + country)
}
func BenchmarkGeoIPMatcher4CN(b *testing.B) {
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("CN")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
b.ResetTimer()
@@ -222,11 +250,12 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
}
func BenchmarkGeoIPMatcher6US(b *testing.B) {
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
ips, err := loadGeoIP("US")
common.Must(err)
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
common.Must(err)
b.ResetTimer()

View File

@@ -1,22 +1,20 @@
package router_test
import (
"os"
"path/filepath"
"runtime"
"strconv"
"testing"
"github.com/xtls/xray-core/app/router"
. "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/features/routing"
routing_session "github.com/xtls/xray-core/features/routing/session"
"github.com/xtls/xray-core/infra/conf"
"google.golang.org/protobuf/proto"
)
func withBackground() routing.Context {
@@ -302,25 +300,32 @@ func TestRoutingRule(t *testing.T) {
}
}
func loadGeoSiteDomains(geo string) ([]*Domain, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
domains, err := conf.ParseDomainRule(geo)
func loadGeoSite(country string) ([]*Domain, error) {
path, err := getAssetPath("geosite.dat")
if err != nil {
return nil, err
}
geositeBytes, err := filesystem.ReadFile(path)
if err != nil {
return nil, err
}
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
domains, err = router.GetDomainList(domains)
if err != nil {
return nil, err
var geositeList GeoSiteList
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
return nil, err
}
for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
}
}
return domains, nil
return nil, errors.New("country not found: " + country)
}
func TestChinaSites(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:cn")
domains, err := loadGeoSite("CN")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
@@ -361,50 +366,8 @@ func TestChinaSites(t *testing.T) {
}
}
func TestChinaSitesWithAttrs(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:google@cn")
common.Must(err)
acMatcher, err := NewMphMatcherGroup(domains)
common.Must(err)
type TestCase struct {
Domain string
Output bool
}
testCases := []TestCase{
{
Domain: "google.cn",
Output: true,
},
{
Domain: "recaptcha.net",
Output: true,
},
{
Domain: "164.com",
Output: false,
},
{
Domain: "164.com",
Output: false,
},
}
for i := 0; i < 1024; i++ {
testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
}
for _, testCase := range testCases {
r := acMatcher.ApplyDomain(testCase.Domain)
if r != testCase.Output {
t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
}
}
}
func BenchmarkMphDomainMatcher(b *testing.B) {
domains, err := loadGeoSiteDomains("geosite:cn")
domains, err := loadGeoSite("CN")
common.Must(err)
matcher, err := NewMphMatcherGroup(domains)
@@ -449,11 +412,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
var geoips []*GeoIP
{
ips, err := loadGeoIP("geoip:cn")
ips, err := loadGeoIP("CN")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CN",
Cidr: ips.Cidr,
Cidr: ips,
})
}
@@ -462,25 +425,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "JP",
Cidr: ips.Cidr,
Cidr: ips,
})
}
{
ips, err := loadGeoIP("geoip:ca")
ips, err := loadGeoIP("CA")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CA",
Cidr: ips.Cidr,
Cidr: ips,
})
}
{
ips, err := loadGeoIP("geoip:us")
ips, err := loadGeoIP("US")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "US",
Cidr: ips.Cidr,
Cidr: ips,
})
}

View File

@@ -3,14 +3,11 @@ package router
import (
"context"
"regexp"
"runtime"
"strings"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/features/routing"
"google.golang.org/protobuf/proto"
)
type Rule struct {
@@ -76,15 +73,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
if len(rr.Geoip) > 0 {
geoip := rr.Geoip
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
geoip, err = GetGeoIPList(rr.Geoip)
if err != nil {
return nil, errors.New("failed to build geoip from mmap").Base(err)
}
}
cond, err := NewIPMatcher(geoip, MatcherAsType_Target)
cond, err := NewIPMatcher(rr.Geoip, MatcherAsType_Target)
if err != nil {
return nil, err
}
@@ -109,20 +98,11 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
if len(rr.Domain) > 0 {
domains := rr.Domain
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
domains, err = GetDomainList(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domains from mmap").Base(err)
}
}
matcher, err := NewMphMatcherGroup(domains)
matcher, err := NewMphMatcherGroup(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
}
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
conds.Add(matcher)
}
@@ -183,80 +163,3 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch
return nil, errors.New("unrecognized balancer type")
}
}
func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
geoipList := []*GeoIP{}
for _, ip := range ips {
if ip.CountryCode != "" {
val := strings.Split(ip.CountryCode, "_")
fileName := "geoip.dat"
if len(val) == 2 {
fileName = strings.ToLower(val[0])
}
bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return nil, errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(ip.CountryCode))
var geoip GeoIP
if err := proto.Unmarshal(bs, &geoip); err != nil {
return nil, errors.New("failed Unmarshal :").Base(err)
}
geoipList = append(geoipList, &geoip)
} else {
geoipList = append(geoipList, ip)
}
}
return geoipList, nil
}
func GetDomainList(domains []*Domain) ([]*Domain, error) {
domainList := []*Domain{}
for _, domain := range domains {
val := strings.Split(domain.Value, "_")
if len(val) >= 2 {
fileName := val[0]
code := val[1]
bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return nil, errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
var geosite GeoSite
if err := proto.Unmarshal(bs, &geosite); err != nil {
return nil, errors.New("failed Unmarshal :").Base(err)
}
// parse attr
if len(val) == 3 {
siteWithAttr := strings.Split(val[2], ",")
attrs := ParseAttrs(siteWithAttr)
if !attrs.IsEmpty() {
filteredDomains := make([]*Domain, 0, len(domains))
for _, domain := range geosite.Domain {
if attrs.Match(domain) {
filteredDomains = append(filteredDomains, domain)
}
}
geosite.Domain = filteredDomains
}
}
domainList = append(domainList, geosite.Domain...)
} else {
domainList = append(domainList, domain)
}
}
return domainList, nil
}

View File

@@ -1,52 +0,0 @@
package filesystem
func DecodeVarint(buf []byte) (x uint64, n int) {
for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) {
return 0, 0
}
b := uint64(buf[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0
}
func Find(data, code []byte) []byte {
codeL := len(code)
if codeL == 0 {
return nil
}
for {
dataL := len(data)
if dataL < 2 {
return nil
}
x, y := DecodeVarint(data[1:])
if x == 0 && y == 0 {
return nil
}
headL, bodyL := 1+y, int(x)
dataL -= headL
if dataL < bodyL {
return nil
}
data = data[headL:]
if int(data[1]) == codeL {
for i := 0; i < codeL && data[2+i] == code[i]; i++ {
if i+1 == codeL {
return data[:bodyL]
}
}
}
if dataL == bodyL {
return nil
}
data = data[bodyL:]
}
}

View File

@@ -1,12 +1,9 @@
//go:build !windows && !wasm
package filesystem
import (
"io"
"os"
"path/filepath"
"syscall"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/platform"
@@ -19,29 +16,6 @@ var NewFileReader FileReaderFunc = func(path string) (io.ReadCloser, error) {
}
func ReadFile(path string) ([]byte, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return nil, err
}
size := stat.Size()
if size == 0 {
return []byte{}, nil
}
// use mmap to save RAM
bs, err := syscall.Mmap(int(file.Fd()), 0, int(size), syscall.PROT_READ, syscall.MAP_SHARED)
if err == nil {
return bs, nil
}
// fallback
reader, err := NewFileReader(path)
if err != nil {
return nil, err

View File

@@ -1,54 +0,0 @@
//go:build windows || wasm
package filesystem
import (
"io"
"os"
"path/filepath"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/platform"
)
type FileReaderFunc func(path string) (io.ReadCloser, error)
var NewFileReader FileReaderFunc = func(path string) (io.ReadCloser, error) {
return os.Open(path)
}
func ReadFile(path string) ([]byte, error) {
reader, err := NewFileReader(path)
if err != nil {
return nil, err
}
defer reader.Close()
return buf.ReadAllToBytes(reader)
}
func ReadAsset(file string) ([]byte, error) {
return ReadFile(platform.GetAssetLocation(file))
}
func ReadCert(file string) ([]byte, error) {
if filepath.IsAbs(file) {
return ReadFile(file)
}
return ReadFile(platform.GetCertLocation(file))
}
func CopyFile(dst string, src string) error {
bytes, err := ReadFile(src)
if err != nil {
return err
}
f, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
_, err = f.Write(bytes)
return err
}

View File

@@ -3,9 +3,7 @@
package platform
import (
"path/filepath"
)
import "path/filepath"
func LineSeparator() string {
return "\r\n"
@@ -14,7 +12,6 @@ func LineSeparator() string {
// GetAssetLocation searches for `file` in the env dir and the executable dir
func GetAssetLocation(file string) string {
assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)
return filepath.Join(assetPath, file)
}

View File

@@ -2,10 +2,12 @@ package singbridge
import (
"context"
"time"
B "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/transport"
@@ -29,6 +31,8 @@ type PacketConnWrapper struct {
cached buf.MultiBuffer
}
// This ReadPacket implemented a timeout to avoid goroutine leak like PipeConnWrapper.Read()
// as a temporarily solution
func (w *PacketConnWrapper) ReadPacket(buffer *B.Buffer) (M.Socksaddr, error) {
if w.cached != nil {
mb, bb := buf.SplitFirst(w.cached)
@@ -47,10 +51,30 @@ func (w *PacketConnWrapper) ReadPacket(buffer *B.Buffer) (M.Socksaddr, error) {
return ToSocksaddr(destination), nil
}
}
mb, err := w.ReadMultiBuffer()
if err != nil {
return M.Socksaddr{}, err
// timeout
type readResult struct {
mb buf.MultiBuffer
err error
}
c := make(chan readResult, 1)
go func() {
mb, err := w.ReadMultiBuffer()
c <- readResult{mb: mb, err: err}
}()
var mb buf.MultiBuffer
select {
case <-time.After(60 * time.Second):
common.Close(w.Reader)
common.Interrupt(w.Reader)
return M.Socksaddr{}, buf.ErrReadTimeout
case result := <-c:
if result.err != nil {
return M.Socksaddr{}, result.err
}
mb = result.mb
}
nb, bb := buf.SplitFirst(mb)
if bb == nil {
return M.Socksaddr{}, nil

View File

@@ -64,7 +64,11 @@ func GetMergedConfig(args cmdarg.Arg) (string, error) {
var files []*ConfigSource
supported := []string{"json", "yaml", "toml"}
for _, file := range args {
format := GetFormat(file)
format := "json"
if file != "stdin:" {
format = GetFormat(file)
}
if slices.Contains(supported, format) {
files = append(files, &ConfigSource{
Name: file,

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
github.com/miekg/dns v1.1.70
github.com/pelletier/go-toml v1.9.5
github.com/pires/go-proxyproto v0.8.1
github.com/refraction-networking/utls v1.8.1
github.com/refraction-networking/utls v1.8.2
github.com/sagernet/sing v0.5.1
github.com/sagernet/sing-shadowsocks v0.2.7
github.com/seiflotfy/cuckoofilter v0.0.0-20240715131351-a2f2c23f1771

4
go.sum
View File

@@ -52,8 +52,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg=
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=

View File

@@ -80,6 +80,21 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
return errors.New("failed to parse name server: ", string(data))
}
func toDomainMatchingType(t router.Domain_Type) dns.DomainMatchingType {
switch t {
case router.Domain_Domain:
return dns.DomainMatchingType_Subdomain
case router.Domain_Full:
return dns.DomainMatchingType_Full
case router.Domain_Plain:
return dns.DomainMatchingType_Keyword
case router.Domain_Regex:
return dns.DomainMatchingType_Regex
default:
panic("unknown domain type")
}
}
func (c *NameServerConfig) Build() (*dns.NameServer, error) {
if c.Address == nil {
return nil, errors.New("NameServer address is not specified.")
@@ -89,14 +104,14 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
var originalRules []*dns.NameServer_OriginalRule
for _, rule := range c.Domains {
parsedDomain, err := ParseDomainRule(rule)
parsedDomain, err := parseDomainRule(rule)
if err != nil {
return nil, errors.New("invalid domain rule: ", rule).Base(err)
}
for _, pd := range parsedDomain {
domains = append(domains, &dns.NameServer_PriorityDomain{
Type: dns.ToDomainMatchingType(pd.Type),
Type: toDomainMatchingType(pd.Type),
Domain: pd.Value,
})
}

View File

@@ -1,19 +1,25 @@
package conf
import (
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/proxy/hysteria"
"google.golang.org/protobuf/proto"
)
type HysteriaClientConfig struct {
Version int32 `json:"version"`
Address *Address `json:"address"`
Port uint16 `json:"port"`
}
func (c *HysteriaClientConfig) Build() (proto.Message, error) {
config := new(hysteria.ClientConfig)
if c.Version != 2 {
return nil, errors.New("version != 2")
}
config := &hysteria.ClientConfig{}
config.Version = c.Version
config.Server = &protocol.ServerEndpoint{
Address: c.Address.Build(),
Port: uint32(c.Port),

View File

@@ -203,23 +203,17 @@ func loadFile(file string) ([]byte, error) {
func loadIP(file, code string) ([]*router.CIDR, error) {
index := file + ":" + code
if IPCache[index] == nil {
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = find(bs, []byte(code))
if bs == nil {
return nil, errors.New("code not found in ", file, ": ", code)
}
var geoip router.GeoIP
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
// dont pass code becuase we have country code in top level router.GeoIP
geoip = router.GeoIP{Cidr: []*router.CIDR{}}
} else {
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
if bs == nil {
return nil, errors.New("code not found in ", file, ": ", code)
}
if err := proto.Unmarshal(bs, &geoip); err != nil {
return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
}
if err := proto.Unmarshal(bs, &geoip); err != nil {
return nil, errors.New("error unmarshal IP in ", file, ": ", code).Base(err)
}
defer runtime.GC() // or debug.FreeOSMemory()
return geoip.Cidr, nil // do not cache geoip
@@ -231,28 +225,18 @@ func loadIP(file, code string) ([]*router.CIDR, error) {
func loadSite(file, code string) ([]*router.Domain, error) {
index := file + ":" + code
if SiteCache[index] == nil {
var geosite router.GeoSite
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
// pass file:code so can build optimized matcher later
domain := router.Domain{Value: file + "_" + code}
geosite = router.GeoSite{Domain: []*router.Domain{&domain}}
} else {
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
if bs == nil {
return nil, errors.New("list not found in ", file, ": ", code)
}
if err := proto.Unmarshal(bs, &geosite); err != nil {
return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
}
bs, err := loadFile(file)
if err != nil {
return nil, errors.New("failed to load file: ", file).Base(err)
}
bs = find(bs, []byte(code))
if bs == nil {
return nil, errors.New("list not found in ", file, ": ", code)
}
var geosite router.GeoSite
if err := proto.Unmarshal(bs, &geosite); err != nil {
return nil, errors.New("error unmarshal Site in ", file, ": ", code).Base(err)
}
defer runtime.GC() // or debug.FreeOSMemory()
return geosite.Domain, nil // do not cache geosite
SiteCache[index] = &geosite
@@ -260,13 +244,105 @@ func loadSite(file, code string) ([]*router.Domain, error) {
return SiteCache[index].Domain, nil
}
func DecodeVarint(buf []byte) (x uint64, n int) {
for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) {
return 0, 0
}
b := uint64(buf[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0
}
func find(data, code []byte) []byte {
codeL := len(code)
if codeL == 0 {
return nil
}
for {
dataL := len(data)
if dataL < 2 {
return nil
}
x, y := DecodeVarint(data[1:])
if x == 0 && y == 0 {
return nil
}
headL, bodyL := 1+y, int(x)
dataL -= headL
if dataL < bodyL {
return nil
}
data = data[headL:]
if int(data[1]) == codeL {
for i := 0; i < codeL && data[2+i] == code[i]; i++ {
if i+1 == codeL {
return data[:bodyL]
}
}
}
if dataL == bodyL {
return nil
}
data = data[bodyL:]
}
}
type AttributeMatcher interface {
Match(*router.Domain) bool
}
type BooleanMatcher string
func (m BooleanMatcher) Match(domain *router.Domain) bool {
for _, attr := range domain.Attribute {
if attr.Key == string(m) {
return true
}
}
return false
}
type AttributeList struct {
matcher []AttributeMatcher
}
func (al *AttributeList) Match(domain *router.Domain) bool {
for _, matcher := range al.matcher {
if !matcher.Match(domain) {
return false
}
}
return true
}
func (al *AttributeList) IsEmpty() bool {
return len(al.matcher) == 0
}
func parseAttrs(attrs []string) *AttributeList {
al := new(AttributeList)
for _, attr := range attrs {
lc := strings.ToLower(attr)
al.matcher = append(al.matcher, BooleanMatcher(lc))
}
return al
}
func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, error) {
parts := strings.Split(siteWithAttr, "@")
if len(parts) == 0 {
return nil, errors.New("empty site")
}
country := strings.ToUpper(parts[0])
attrs := router.ParseAttrs(parts[1:])
attrs := parseAttrs(parts[1:])
domains, err := loadSite(file, country)
if err != nil {
return nil, err
@@ -276,11 +352,6 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return domains, nil
}
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
domains[0].Value = domains[0].Value + "_" + strings.Join(parts[1:], ",")
return domains, nil
}
filteredDomains := make([]*router.Domain, 0, len(domains))
for _, domain := range domains {
if attrs.Match(domain) {
@@ -291,7 +362,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return filteredDomains, nil
}
func ParseDomainRule(domain string) ([]*router.Domain, error) {
func parseDomainRule(domain string) ([]*router.Domain, error) {
if strings.HasPrefix(domain, "geosite:") {
country := strings.ToUpper(domain[8:])
domains, err := loadGeositeWithAttr("geosite.dat", country)
@@ -489,7 +560,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domain != nil {
for _, domain := range *rawFieldRule.Domain {
rules, err := ParseDomainRule(domain)
rules, err := parseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}
@@ -499,7 +570,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
if rawFieldRule.Domains != nil {
for _, domain := range *rawFieldRule.Domains {
rules, err := ParseDomainRule(domain)
rules, err := parseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}

View File

@@ -453,7 +453,7 @@ func (c *HysteriaConfig) Build() (proto.Message, error) {
}
config := &hysteria.Config{}
config.Version = int32(c.Version)
config.Version = c.Version
config.Auth = c.Auth
config.Up = up
config.Down = down

View File

@@ -24,7 +24,8 @@ const (
type ClientConfig struct {
state protoimpl.MessageState `protogen:"open.v1"`
Server *protocol.ServerEndpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"`
Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"`
Server *protocol.ServerEndpoint `protobuf:"bytes,2,opt,name=server,proto3" json:"server,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -59,6 +60,13 @@ func (*ClientConfig) Descriptor() ([]byte, []int) {
return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{0}
}
func (x *ClientConfig) GetVersion() int32 {
if x != nil {
return x.Version
}
return 0
}
func (x *ClientConfig) GetServer() *protocol.ServerEndpoint {
if x != nil {
return x.Server
@@ -70,9 +78,10 @@ var File_proxy_hysteria_config_proto protoreflect.FileDescriptor
const file_proxy_hysteria_config_proto_rawDesc = "" +
"\n" +
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\"L\n" +
"\fClientConfig\x12<\n" +
"\x06server\x18\x01 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06serverB[\n" +
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\"f\n" +
"\fClientConfig\x12\x18\n" +
"\aversion\x18\x01 \x01(\x05R\aversion\x12<\n" +
"\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06serverB[\n" +
"\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3"
var (

View File

@@ -9,5 +9,6 @@ option java_multiple_files = true;
import "common/protocol/server_spec.proto";
message ClientConfig {
xray.common.protocol.ServerEndpoint server = 1;
int32 version = 1;
xray.common.protocol.ServerEndpoint server = 2;
}

View File

@@ -195,6 +195,14 @@ func (c *PacketConnWrapper) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t)
}
func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) {
sc, ok := c.Conn.(syscall.Conn)
if !ok {
return nil, syscall.EINVAL
}
return sc.SyscallConn()
}
type SystemDialerAdapter interface {
Dial(network string, address string) (net.Conn, error)
}

View File

@@ -281,35 +281,41 @@ func (c *Config) parseServerName() string {
}
func (r *RandCarrier) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (err error) {
// extract x509 certificates from rawCerts(verifiedChains will be nil if InsecureSkipVerify is true)
// extract x509 certificates from rawCerts (verifiedChains will be nil if InsecureSkipVerify is true)
certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
certs[i], _ = x509.ParseCertificate(asn1Data)
}
if len(certs) == 0 {
return errors.New("unexpected certs")
}
if certs[0].IsCA {
slices.Reverse(certs)
}
// directly return success if pinned cert is leaf
// or add the CA to RootCAs if pinned cert is CA(and can be used in VerifyPeerCertInNames for Self signed CA)
RootCAs := r.RootCAs
// or replace RootCAs if pinned cert is CA (and can be used in VerifyPeerCertInNames)
CAs := r.RootCAs
var verifyResult verifyResult
var verifiedCert *x509.Certificate
if r.PinnedPeerCertSha256 != nil {
verifyResult, verifiedCert = verifyChain(certs, r.PinnedPeerCertSha256)
switch verifyResult {
case certNotFound:
return errors.New("peer cert is unrecognized")
return errors.New("peer cert is unrecognized (againsts pinnedPeerCertSha256)")
case foundLeaf:
return nil
case foundCA:
RootCAs = x509.NewCertPool()
RootCAs.AddCert(verifiedCert)
CAs = x509.NewCertPool()
CAs.AddCert(verifiedCert)
default:
panic("impossible PinnedPeerCertificateSha256 verify result")
panic("impossible pinnedPeerCertSha256 verify result")
}
}
if len(r.VerifyPeerCertInNames) > 0 {
if r.VerifyPeerCertInNames != nil { // RAW's Dial() may make it empty but not nil
opts := x509.VerifyOptions{
Roots: RootCAs,
Roots: CAs,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
}
@@ -321,9 +327,15 @@ func (r *RandCarrier) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509
return nil
}
}
} else if len(verifiedChains) == 0 && verifyResult == foundCA { // if found ca and verifiedChains is empty, we need to verify here
if verifyResult == foundCA {
errors.New("peer cert is invalid (againsts pinned CA and verifyPeerCertInNames)")
}
return errors.New("peer cert is invalid (againsts root CAs and verifyPeerCertInNames)")
}
if verifyResult == foundCA { // if found CA, we need to verify here
opts := x509.VerifyOptions{
Roots: RootCAs,
Roots: CAs,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
DNSName: r.Config.ServerName,
@@ -334,8 +346,10 @@ func (r *RandCarrier) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509
if _, err := certs[0].Verify(opts); err == nil {
return nil
}
return errors.New("peer cert is invalid (againsts pinned CA and serverName)")
}
return nil
return nil // len(r.PinnedPeerCertSha256)==nil && len(r.VerifyPeerCertInNames)==nil
}
type RandCarrier struct {
@@ -386,6 +400,11 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
} else {
randCarrier.VerifyPeerCertInNames = nil
}
if len(c.PinnedPeerCertSha256) > 0 {
config.InsecureSkipVerify = true
} else {
randCarrier.PinnedPeerCertSha256 = nil
}
for _, opt := range opts {
opt(config)
@@ -540,17 +559,16 @@ const (
foundCA
)
func verifyChain(certs []*x509.Certificate, PinnedPeerCertificateSha256 [][]byte) (verifyResult, *x509.Certificate) {
func verifyChain(certs []*x509.Certificate, pinnedPeerCertSha256 [][]byte) (verifyResult, *x509.Certificate) {
for _, cert := range certs {
certHash := GenerateCertHash(cert)
for _, c := range PinnedPeerCertificateSha256 {
for _, c := range pinnedPeerCertSha256 {
if hmac.Equal(certHash, c) {
if cert.IsCA {
return foundCA, cert
} else {
return foundLeaf, cert
}
}
}
}

View File

@@ -1,12 +1,15 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"testing"
"github.com/stretchr/testify/assert"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/protocol/tls/cert"
)
func TestCalculateCertHash(t *testing.T) {
@@ -95,3 +98,60 @@ uI6HqHFD3iEct8fBkYfQiwH2e1eu9OwgujiWHsutyK8VvzVB3/YnhQ/TzciRjPqz
assert.Equal(t, fingerprint, hash)
})
}
func TestVerifyPeerLeafCert(t *testing.T) {
leafCert := cert.MustGenerate(nil, cert.DNSNames("example.com"))
leaf := common.Must2(x509.ParseCertificate(leafCert.Certificate))
caHash := GenerateCertHash(leafCert.Certificate)
r := &RandCarrier{
Config: &tls.Config{
ServerName: "example.com",
},
PinnedPeerCertSha256: [][]byte{caHash},
}
rawCerts := [][]byte{leaf.Raw}
err := r.verifyPeerCert(rawCerts, nil)
if err != nil {
t.Fatal("expected to verify leaf cert signed by pinned CA, but got error:", err)
}
// make the pinned hash incorrect
r.PinnedPeerCertSha256[0][0] += 1
err = r.verifyPeerCert(rawCerts, nil)
if err == nil {
t.Fatal("expected to fail verifying leaf cert with incorrect pinned CA hash, but got no error")
}
}
func TestVerifyPeerCACert(t *testing.T) {
caCert := cert.MustGenerate(nil, cert.Authority(true), cert.KeyUsage(x509.KeyUsageCertSign))
ca := common.Must2(x509.ParseCertificate(caCert.Certificate))
leafCert := cert.MustGenerate(caCert, cert.DNSNames("example.com"))
leaf := common.Must2(x509.ParseCertificate(leafCert.Certificate))
caHash := GenerateCertHash(ca)
r := &RandCarrier{
Config: &tls.Config{
ServerName: "example.com",
},
PinnedPeerCertSha256: [][]byte{caHash},
}
rawCerts := [][]byte{leaf.Raw, ca.Raw}
err := r.verifyPeerCert(rawCerts, nil)
if err != nil {
t.Fatal("expected to verify leaf cert signed by pinned CA, but got error:", err)
}
// make the pinned hash incorrect
r.PinnedPeerCertSha256[0][0] += 1
err = r.verifyPeerCert(rawCerts, nil)
if err == nil {
t.Fatal("expected to fail verifying leaf cert with incorrect pinned CA hash, but got no error")
}
}