Compare commits

..

4 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
447b4438ee Restore inbound and content tag setting before dispatch
As identified in issue #4760, PR #4030 commented out lines that set
inbound and content tags from routing info before dispatch. This broke
domain-based routing for WireGuard connections.

The fix adds back these lines (with mutex-protected access) positioned
right before the Dispatch call, ensuring routing configuration is
properly passed for domain-based routing rules to work.

This addresses the feedback that uncommenting those lines fixes routing.

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
2026-01-11 09:13:06 +00:00
copilot-swe-agent[bot]
1f925778ed Complete WireGuard domain routing fix with tests and review
Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
2026-01-10 22:18:32 +00:00
copilot-swe-agent[bot]
06dee57f8a Add mutex protection for WireGuard routing info
Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
2026-01-10 22:14:55 +00:00
copilot-swe-agent[bot]
7f8928272b Initial plan 2026-01-10 22:04:22 +00:00
7 changed files with 66 additions and 318 deletions

View File

@@ -9,7 +9,6 @@ import (
"github.com/xtls/xray-core/common/mux"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/stats"
@@ -53,20 +52,6 @@ type AlwaysOnInboundHandler struct {
}
func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*AlwaysOnInboundHandler, error) {
// Set tag and sniffing config in context before creating proxy
// This allows proxies like TUN to access these settings
ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: tag})
if receiverConfig.SniffingSettings != nil {
ctx = session.ContextWithContent(ctx, &session.Content{
SniffingRequest: session.SniffingRequest{
Enabled: receiverConfig.SniffingSettings.Enabled,
OverrideDestinationForProtocol: receiverConfig.SniffingSettings.DestinationOverride,
ExcludeForDomain: receiverConfig.SniffingSettings.DomainsExcluded,
MetadataOnly: receiverConfig.SniffingSettings.MetadataOnly,
RouteOnly: receiverConfig.SniffingSettings.RouteOnly,
},
})
}
rawProxy, err := common.CreateObject(ctx, proxyConfig)
if err != nil {
return nil, err

View File

@@ -52,7 +52,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) {
return
}
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") {
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
h := blake3.New(8, BaseKey)
h.Write([]byte(inbound.Source.String()))
copy(globalID[:], h.Sum(nil))

View File

@@ -130,10 +130,7 @@ type InboundDetourConfig struct {
func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
receiverSettings := &proxyman.ReceiverConfig{}
// TUN inbound doesn't need port configuration as it uses network interface instead
if strings.ToLower(c.Protocol) == "tun" {
// Skip port validation for TUN
} else if c.ListenOn == nil {
if c.ListenOn == nil {
// Listen on anyip, must set PortList
if c.PortList == nil {
return nil, errors.New("Listen on AnyIP but no Port(s) set in InboundDetour.")

View File

@@ -7,7 +7,6 @@ import (
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
@@ -20,13 +19,11 @@ import (
// Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
type Handler struct {
ctx context.Context
config *Config
stack Stack
policyManager policy.Manager
dispatcher routing.Dispatcher
tag string
sniffingRequest session.SniffingRequest
ctx context.Context
config *Config
stack Stack
policyManager policy.Manager
dispatcher routing.Dispatcher
}
// ConnectionHandler interface with the only method that stack is going to push new connections to
@@ -46,14 +43,6 @@ func (t *Handler) policy() policy.Session {
func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
var err error
// Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler)
if inbound := session.InboundFromContext(ctx); inbound != nil {
t.tag = inbound.Tag
}
if content := session.ContentFromContext(ctx); content != nil {
t.sniffingRequest = content.SniffingRequest
}
t.ctx = core.ToBackgroundDetachedContext(ctx)
t.policyManager = pm
t.dispatcher = dispatcher
@@ -104,38 +93,29 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin
func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
sid := session.NewID()
ctx := c.ContextWithID(t.ctx, sid)
errors.LogInfo(ctx, "processing connection from: ", conn.RemoteAddr())
inbound := session.Inbound{
Name: "tun",
Tag: t.tag,
CanSpliceCopy: 3,
Source: net.DestinationFromAddr(conn.RemoteAddr()),
User: &protocol.MemoryUser{
Level: t.config.UserLevel,
},
inbound := session.Inbound{}
inbound.Name = "tun"
inbound.CanSpliceCopy = 1
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
inbound.User = &protocol.MemoryUser{
Level: t.config.UserLevel,
}
ctx = session.ContextWithInbound(ctx, &inbound)
ctx = session.ContextWithContent(ctx, &session.Content{
SniffingRequest: t.sniffingRequest,
})
ctx = session.SubContextFromMuxInbound(ctx)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source,
To: destination,
Status: log.AccessAccepted,
Reason: "",
})
errors.LogInfo(ctx, "processing TCP from ", conn.RemoteAddr(), " to ", destination)
link := &transport.Link{
Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
Writer: buf.NewWriter(conn),
}
if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil {
errors.LogError(ctx, errors.New("connection closed").Base(err))
return
}
errors.LogInfo(ctx, "connection completed")
}
// Network implements proxy.Inbound

View File

@@ -6,7 +6,6 @@ import (
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -101,35 +100,32 @@ func (t *stackGVisor) Start() error {
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
// Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support
udpForwarder := newUdpConnectionHandler(t.ctx, t.handler, func(p []byte) {
// extract network protocol from the packet
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(p) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
default:
// discard packet with unknown network version
return
}
udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
go func(r *udp.ForwarderRequest) {
var wq waiter.Queue
var id = r.ID()
ipStack.WriteRawPacket(defaultNIC, networkProtocol, buffer.MakeWithData(p))
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
data := pkt.Data().AsRange().ToSlice()
if len(data) == 0 {
return false
}
// source/destination of the packet we process as incoming, on gVisor side are Remote/Local
// in other terms, src is the side behind tun, dst is the side behind gVisor
// this function handle packets passing from the tun to the gVisor, therefore the src/dst assignement
src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
ep, err := r.CreateEndpoint(&wq)
if err != nil {
errors.LogError(t.ctx, err.String())
return
}
return udpForwarder.HandlePacket(src, dst, data)
options := ep.SocketOptions()
options.SetReuseAddress(true)
options.SetReusePort(true)
t.handler.HandleConnection(
gonet.NewUDPConn(&wq, ep),
// local address on the gVisor side is connection destination
net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)),
)
// close the socket
ep.Close()
}(r)
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint

View File

@@ -1,228 +0,0 @@
package tun
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/pipe"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// udp connection abstraction
type udpConn struct {
lastActive atomic.Int64
reader buf.Reader
writer buf.Writer
done *done.Instance
cancel context.CancelFunc
}
// sub-handler specifically for udp connections under main handler
type udpConnectionHandler struct {
sync.Mutex
ctx context.Context
handler *Handler
udpConns map[net.Destination]*udpConn
udpChecker *task.Periodic
writePacket func(p []byte)
}
func newUdpConnectionHandler(ctx context.Context, h *Handler, writePacket func(p []byte)) *udpConnectionHandler {
handler := &udpConnectionHandler{
ctx: ctx,
handler: h,
udpConns: make(map[net.Destination]*udpConn),
writePacket: writePacket,
}
handler.udpChecker = &task.Periodic{Interval: time.Minute, Execute: handler.cleanupUDP}
handler.udpChecker.Start()
return handler
}
func (u *udpConnectionHandler) cleanupUDP() error {
u.Lock()
defer u.Unlock()
if len(u.udpConns) == 0 {
return errors.New("no connections")
}
now := time.Now().Unix()
for src, conn := range u.udpConns {
if now-conn.lastActive.Load() > 300 {
conn.cancel()
common.Must(conn.done.Close())
common.Must(common.Close(conn.writer))
delete(u.udpConns, src)
}
}
return nil
}
// HandlePacket handles UDP packets coming from tun, to forward to the dispatcher
// this custom handler support FullCone NAT of returning packets, binding connection only by the source port
func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool {
u.Lock()
conn, found := u.udpConns[src]
if !found {
reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
conn = &udpConn{reader: reader, writer: writer, done: done.New()}
u.udpConns[src] = conn
u.Unlock()
go func() {
ctx, cancel := context.WithCancel(u.ctx)
conn.cancel = cancel
defer func() {
cancel()
u.Lock()
delete(u.udpConns, src)
u.Unlock()
common.Must(conn.done.Close())
common.Must(common.Close(conn.writer))
}()
inbound := &session.Inbound{
Name: "tun",
Tag: u.handler.tag,
Source: src,
CanSpliceCopy: 3,
User: &protocol.MemoryUser{Level: u.handler.config.UserLevel},
}
ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
ctx = session.ContextWithContent(ctx, &session.Content{
SniffingRequest: u.handler.sniffingRequest,
})
ctx = session.SubContextFromMuxInbound(ctx)
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: src,
To: dst,
Status: log.AccessAccepted,
Reason: "",
})
errors.LogInfo(ctx, "processing UDP from ", src, " to ", dst)
link := &transport.Link{
Reader: &buf.TimeoutWrapperReader{Reader: conn.reader},
// reverse source and destination, indicating the packets to write are going in the other
// direction (written back to tun) and should have reversed addressing
Writer: &udpWriter{handler: u, src: dst, dst: src},
}
_ = u.handler.dispatcher.DispatchLink(ctx, dst, link)
}()
} else {
conn.lastActive.Store(time.Now().Unix())
u.Unlock()
}
b := buf.New()
b.Write(data)
b.UDP = &dst
conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
return true
}
type udpWriter struct {
handler *udpConnectionHandler
// address in the side of stack, where packet will be coming from
src net.Destination
// address on the side of tun, where packet will be destined to
dst net.Destination
}
func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
for _, b := range mb {
// use captured in the dispatched packet source address b.UDP as source, if available,
// otherwise use captured in the writer source w.src
srcAddr := w.src
if b.UDP != nil {
srcAddr = *b.UDP
}
// validate address family matches
if srcAddr.Address.Family() != w.src.Address.Family() {
errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family())
b.Release()
continue
}
payload := b.Bytes()
udpLen := header.UDPMinimumSize + len(payload)
srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP())
dstIP := tcpip.AddrFromSlice(w.dst.Address.IP())
// build packet with appropriate IP header size
isIPv4 := srcAddr.Address.Family().IsIPv4()
ipHdrSize := header.IPv6MinimumSize
if isIPv4 {
ipHdrSize = header.IPv4MinimumSize
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
Payload: buffer.MakeWithData(payload),
})
// Build UDP header
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
udpHdr.Encode(&header.UDPFields{
SrcPort: uint16(srcAddr.Port),
DstPort: uint16(w.dst.Port),
Length: uint16(udpLen),
})
// Calculate and set UDP checksum
xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
// Build IP header
if isIPv4 {
ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
ipHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + udpLen),
TTL: 64,
Protocol: uint8(header.UDPProtocolNumber),
SrcAddr: srcIP,
DstAddr: dstIP,
})
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
} else {
ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ipHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(udpLen),
TransportProtocol: header.UDPProtocolNumber,
HopLimit: 64,
SrcAddr: srcIP,
DstAddr: dstIP,
})
}
// Write raw packet to network stack
views := pkt.AsSlices()
var data []byte
for _, view := range views {
data = append(data, view...)
}
w.handler.writePacket(data)
pkt.DecRef()
b.Release()
}
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
goerrors "errors"
"io"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
@@ -26,6 +27,7 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0)
type Server struct {
bindServer *netBindServer
infoMu sync.RWMutex
info routingInfo
policyManager policy.Manager
}
@@ -78,12 +80,14 @@ func (*Server) Network() []net.Network {
// Process implements proxy.Inbound.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
s.infoMu.Lock()
s.info = routingInfo{
ctx: ctx,
dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
}
s.infoMu.Unlock()
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
if err != nil {
@@ -120,18 +124,23 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
}
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
if s.info.dispatcher == nil {
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
// Make a thread-safe copy of routing info
s.infoMu.RLock()
info := s.info
s.infoMu.RUnlock()
if info.dispatcher == nil {
errors.LogError(info.ctx, "unexpected: dispatcher == nil")
return
}
defer conn.Close()
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx))
sid := session.NewID()
ctx = c.ContextWithID(ctx, sid)
inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
if s.info.inboundTag != nil {
inbound = *s.info.inboundTag
if info.inboundTag != nil {
inbound = *info.inboundTag
}
inbound.Name = "wireguard"
inbound.CanSpliceCopy = 3
@@ -141,8 +150,8 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
// Currently we have no way to link to the original source address
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
ctx = session.ContextWithInbound(ctx, &inbound)
if s.info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag)
if info.contentTag != nil {
ctx = session.ContextWithContent(ctx, info.contentTag)
}
ctx = session.SubContextFromMuxInbound(ctx)
@@ -156,7 +165,16 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
Reason: "",
})
link, err := s.info.dispatcher.Dispatch(ctx, dest)
// Set inbound and content tags from routing info for proper routing
// These were commented out in PR #4030 but are needed for domain-based routing
if info.inboundTag != nil {
ctx = session.ContextWithInbound(ctx, info.inboundTag)
}
if info.contentTag != nil {
ctx = session.ContextWithContent(ctx, info.contentTag)
}
link, err := info.dispatcher.Dispatch(ctx, dest)
if err != nil {
errors.LogErrorInner(ctx, err, "dispatch connection")
}