Mirror of https://github.com/XTLS/Xray-core.git (synced 2026-01-13 14:17:09 +08:00)

Compare commits: main...copilot/fi (13 commits)
Commits in this range:

- 31d10f3544
- 1ad1608581
- c00c697b65
- 4e0a87faf4
- 2d37e84d4d
- 47a1e042e4
- cc36c1b5bf
- ea3badc641
- 41050594e5
- ecef77ff48
- 52f7f3d174
- 385867e82b
- a99fe66467
```diff
@@ -160,7 +160,7 @@ func (s *ClassicNameServer) getCacheController() *CacheController {
 }
 
 // sendQuery implements CachedNameserver.
-func (s *ClassicNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
+func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, fqdn string, option dns_feature.IPOption) {
 	errors.LogInfo(ctx, s.Name(), " querying DNS for: ", fqdn)
 
 	reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
@@ -171,14 +171,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<
 			ctx: ctx,
 		}
 		s.addPendingRequest(udpReq)
-		b, err := dns.PackMessage(req.msg)
-		if err != nil {
-			errors.LogErrorInner(ctx, err, "failed to pack dns query")
-			if noResponseErrCh != nil {
-				noResponseErrCh <- err
-			}
-			return
-		}
+		b, _ := dns.PackMessage(req.msg)
 		copyDest := net.UDPDestination(s.address.Address, s.address.Port)
 		b.UDP = &copyDest
 		s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
```
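The removed lines above use an optional error channel: a nil channel means the caller does not want failure notifications. A minimal standalone sketch of that pattern (illustrative names, not Xray's API):

```go
package main

import (
	"errors"
	"fmt"
)

// report sends err to errCh only when the caller supplied a channel;
// a nil channel means "no one is listening", mirroring noResponseErrCh.
func report(errCh chan<- error, err error) {
	if errCh != nil {
		errCh <- err
	}
}

func main() {
	ch := make(chan error, 1)
	report(ch, errors.New("failed to pack dns query"))
	report(nil, errors.New("dropped silently")) // safe: the nil check skips the send
	fmt.Println(<-ch)
}
```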
```diff
@@ -9,7 +9,6 @@ import (
 	"github.com/xtls/xray-core/common/mux"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/serial"
-	"github.com/xtls/xray-core/common/session"
 	"github.com/xtls/xray-core/core"
 	"github.com/xtls/xray-core/features/policy"
 	"github.com/xtls/xray-core/features/stats"
```
```diff
@@ -53,20 +52,6 @@ type AlwaysOnInboundHandler struct {
 }
 
 func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*AlwaysOnInboundHandler, error) {
-	// Set tag and sniffing config in context before creating proxy
-	// This allows proxies like TUN to access these settings
-	ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: tag})
-	if receiverConfig.SniffingSettings != nil {
-		ctx = session.ContextWithContent(ctx, &session.Content{
-			SniffingRequest: session.SniffingRequest{
-				Enabled:                        receiverConfig.SniffingSettings.Enabled,
-				OverrideDestinationForProtocol: receiverConfig.SniffingSettings.DestinationOverride,
-				ExcludeForDomain:               receiverConfig.SniffingSettings.DomainsExcluded,
-				MetadataOnly:                   receiverConfig.SniffingSettings.MetadataOnly,
-				RouteOnly:                      receiverConfig.SniffingSettings.RouteOnly,
-			},
-		})
-	}
 	rawProxy, err := common.CreateObject(ctx, proxyConfig)
 	if err != nil {
 		return nil, err
```
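The removed block passed the inbound tag and sniffing settings to the proxy through the request context. A minimal generic sketch of that context-passing pattern with context.WithValue; the types and helpers here are simplified stand-ins for xray-core's session package:

```go
package main

import (
	"context"
	"fmt"
)

type inboundKey struct{}

// Inbound is a stand-in for session.Inbound; only Tag matters here.
type Inbound struct{ Tag string }

// ContextWithInbound and InboundFromContext mirror the session helpers
// used in the removed block, implemented with context.WithValue.
func ContextWithInbound(ctx context.Context, in *Inbound) context.Context {
	return context.WithValue(ctx, inboundKey{}, in)
}

func InboundFromContext(ctx context.Context) *Inbound {
	in, _ := ctx.Value(inboundKey{}).(*Inbound)
	return in
}

func main() {
	ctx := ContextWithInbound(context.Background(), &Inbound{Tag: "tun-in"})
	fmt.Println(InboundFromContext(ctx).Tag) // the proxy can read its tag later
}
```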
```diff
@@ -52,7 +52,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) {
 		return
 	}
 	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
-		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") {
+		(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") {
 		h := blake3.New(8, BaseKey)
 		h.Write([]byte(inbound.Source.String()))
 		copy(globalID[:], h.Sum(nil))
```
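For reference, the GlobalID in this hunk is an 8-byte keyed BLAKE3 digest of the connection's source address. A minimal standalone sketch, assuming lukechampine.com/blake3 (the module pinned in go.mod below) and a zeroed 32-byte key standing in for the real BaseKey:

```go
package main

import (
	"fmt"

	"lukechampine.com/blake3"
)

// globalID derives a stable 8-byte identifier from a UDP source address,
// as in GetGlobalID above: blake3.New(8, key) yields a keyed 8-byte digest.
func globalID(key []byte, source string) (id [8]byte) {
	h := blake3.New(8, key)
	h.Write([]byte(source))
	copy(id[:], h.Sum(nil))
	return
}

func main() {
	key := make([]byte, 32) // stand-in for BaseKey, which must be 32 bytes
	fmt.Printf("%x\n", globalID(key, "198.51.100.7:53124"))
}
```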
go.mod (+2 -2)

```diff
@@ -1,6 +1,6 @@
 module github.com/xtls/xray-core
 
-go 1.25.5
+go 1.25
 
 require (
 	github.com/cloudflare/circl v1.6.2
@@ -29,7 +29,7 @@ require (
 	golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
 	google.golang.org/grpc v1.78.0
 	google.golang.org/protobuf v1.36.11
-	gvisor.dev/gvisor v0.0.0-20260109181451-4be7c433dae2
+	gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5
 	h12.io/socks v1.0.3
 	lukechampine.com/blake3 v1.4.1
 )
```
go.sum (+2 -2)

```diff
@@ -156,8 +156,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
 gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gvisor.dev/gvisor v0.0.0-20260109181451-4be7c433dae2 h1:fr6L00yGG2RP5NMea6njWpdC+bm+cMdFClrSpaicp1c=
-gvisor.dev/gvisor v0.0.0-20260109181451-4be7c433dae2/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
+gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=
+gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
 h12.io/socks v1.0.3 h1:Ka3qaQewws4j4/eDQnOdpr4wXsC//dXtWvftlIcCQUo=
 h12.io/socks v1.0.3/go.mod h1:AIhxy1jOId/XCz9BO+EIgNL2rQiPTBNnOfnVnQ+3Eck=
 lukechampine.com/blake3 v1.4.1 h1:I3Smz7gso8w4/TunLKec6K2fn+kyKtDxr/xcQEN84Wg=
```
```diff
@@ -130,10 +130,7 @@ type InboundDetourConfig struct {
 func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) {
 	receiverSettings := &proxyman.ReceiverConfig{}
 
-	// TUN inbound doesn't need port configuration as it uses network interface instead
-	if strings.ToLower(c.Protocol) == "tun" {
-		// Skip port validation for TUN
-	} else if c.ListenOn == nil {
+	if c.ListenOn == nil {
 		// Listen on anyip, must set PortList
 		if c.PortList == nil {
 			return nil, errors.New("Listen on AnyIP but no Port(s) set in InboundDetour.")
```
```diff
@@ -7,7 +7,6 @@ import (
 	"github.com/xtls/xray-core/common/buf"
 	c "github.com/xtls/xray-core/common/ctx"
 	"github.com/xtls/xray-core/common/errors"
-	"github.com/xtls/xray-core/common/log"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
```
```diff
@@ -20,13 +19,11 @@ import (
 
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 type Handler struct {
-	ctx             context.Context
-	config          *Config
-	stack           Stack
-	policyManager   policy.Manager
-	dispatcher      routing.Dispatcher
-	tag             string
-	sniffingRequest session.SniffingRequest
+	ctx           context.Context
+	config        *Config
+	stack         Stack
+	policyManager policy.Manager
+	dispatcher    routing.Dispatcher
 }
 
 // ConnectionHandler interface with the only method that stack is going to push new connections to
```
```diff
@@ -46,14 +43,6 @@ func (t *Handler) policy() policy.Session {
 func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
 	var err error
 
-	// Retrieve tag and sniffing config from context (set by AlwaysOnInboundHandler)
-	if inbound := session.InboundFromContext(ctx); inbound != nil {
-		t.tag = inbound.Tag
-	}
-	if content := session.ContentFromContext(ctx); content != nil {
-		t.sniffingRequest = content.SniffingRequest
-	}
-
 	t.ctx = core.ToBackgroundDetachedContext(ctx)
 	t.policyManager = pm
 	t.dispatcher = dispatcher
```
```diff
@@ -104,39 +93,29 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin
 func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) {
 	sid := session.NewID()
 	ctx := c.ContextWithID(t.ctx, sid)
+	errors.LogInfo(ctx, "processing connection from: ", conn.RemoteAddr())
 
-	source := net.DestinationFromAddr(conn.RemoteAddr())
-	inbound := session.Inbound{
-		Name:          "tun",
-		Tag:           t.tag,
-		CanSpliceCopy: 3,
-		Source:        source,
-		User: &protocol.MemoryUser{
-			Level: t.config.UserLevel,
-		},
-	}
+	inbound := session.Inbound{}
+	inbound.Name = "tun"
+	inbound.CanSpliceCopy = 1
+	inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
+	inbound.User = &protocol.MemoryUser{
+		Level: t.config.UserLevel,
+	}
 
 	ctx = session.ContextWithInbound(ctx, &inbound)
-	ctx = session.ContextWithContent(ctx, &session.Content{
-		SniffingRequest: t.sniffingRequest,
-	})
-
-	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
-		From:   inbound.Source,
-		To:     destination,
-		Status: log.AccessAccepted,
-		Reason: "",
-	})
-	errors.LogInfo(ctx, "processing from ", source, " to ", destination)
+	ctx = session.SubContextFromMuxInbound(ctx)
 
 	link := &transport.Link{
 		Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
 		Writer: buf.NewWriter(conn),
 	}
 	if err := t.dispatcher.DispatchLink(ctx, destination, link); err != nil {
 		errors.LogError(ctx, errors.New("connection closed").Base(err))
		return
 	}
 
 	errors.LogInfo(ctx, "connection completed")
 }
 
 // Network implements proxy.Inbound
```
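The transport.Link literal above pairs a reader and a writer around the accepted connection before handing it to the dispatcher. A simplified stand-in using io interfaces to show the shape of that hand-off (this is an illustration, not xray-core's transport.Link):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// Link pairs the two directions of a proxied connection, like the
// transport.Link in the hunk above (reduced to io interfaces).
type Link struct {
	Reader io.Reader
	Writer io.Writer
}

// dispatch relays one direction; the real dispatcher pumps both sides
// and applies routing and policy on top.
func dispatch(l *Link) error {
	_, err := io.Copy(l.Writer, l.Reader)
	return err
}

func main() {
	var out strings.Builder
	l := &Link{Reader: strings.NewReader("payload"), Writer: &out}
	if err := dispatch(l); err != nil {
		fmt.Println("connection closed:", err)
		return
	}
	fmt.Println("connection completed:", out.String())
}
```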
```diff
@@ -6,10 +6,8 @@ import (
 
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
-	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
-	"gvisor.dev/gvisor/pkg/tcpip/checksum"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
 	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
```
```diff
@@ -102,21 +100,32 @@ func (t *stackGVisor) Start() error {
 	})
 	ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
-	// Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support
-	udpForwarder := newUdpConnectionHandler(t.handler.HandleConnection, t.writeRawUDPPacket)
-	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
-		data := pkt.Data().AsRange().ToSlice()
-		if len(data) == 0 {
-			return false
-		}
-		// source/destination of the packet we process as incoming, on gVisor side are Remote/Local
-		// in other terms, src is the side behind tun, dst is the side behind gVisor
-		// this function handle packets passing from the tun to the gVisor, therefore the src/dst assignement
-		src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
-		dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
-
-		return udpForwarder.HandlePacket(src, dst, data)
-	})
+	udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
+		go func(r *udp.ForwarderRequest) {
+			var wq waiter.Queue
+			var id = r.ID()
+
+			ep, err := r.CreateEndpoint(&wq)
+			if err != nil {
+				errors.LogError(t.ctx, err.String())
+				return
+			}
+
+			options := ep.SocketOptions()
+			options.SetReuseAddress(true)
+			options.SetReusePort(true)
+
+			t.handler.HandleConnection(
+				gonet.NewUDPConn(&wq, ep),
+				// local address on the gVisor side is connection destination
+				net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)),
+			)
+
+			// close the socket
+			ep.Close()
+		}(r)
+	})
+	ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
 
 	t.stack = ipStack
 	t.endpoint = linkEndpoint
```
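The removed comments describe a naming flip: for a packet arriving from the tun device, gVisor's Remote* fields are the packet's source and its Local* fields the destination. A tiny illustration of that mapping with simplified types (not gVisor's actual stack.TransportEndpointID):

```go
package main

import "fmt"

// endpointID is a simplified stand-in for gVisor's stack.TransportEndpointID.
type endpointID struct {
	LocalAddress, RemoteAddress string
	LocalPort, RemotePort       uint16
}

// srcDst applies the mapping from the removed comments: for a packet
// entering from tun, Remote* is the source and Local* the destination.
func srcDst(id endpointID) (src, dst string) {
	src = fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort)
	dst = fmt.Sprintf("%s:%d", id.LocalAddress, id.LocalPort)
	return
}

func main() {
	src, dst := srcDst(endpointID{
		LocalAddress: "8.8.8.8", LocalPort: 53, // side behind gVisor
		RemoteAddress: "10.0.0.2", RemotePort: 51000, // side behind tun
	})
	fmt.Println(src, "->", dst) // 10.0.0.2:51000 -> 8.8.8.8:53
}
```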
```diff
@@ -124,69 +133,6 @@ func (t *stackGVisor) Start() error {
 	return nil
 }
 
-func (t *stackGVisor) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
-	udpLen := header.UDPMinimumSize + len(payload)
-	srcIP := tcpip.AddrFromSlice(src.Address.IP())
-	dstIP := tcpip.AddrFromSlice(dst.Address.IP())
-
-	// build packet with appropriate IP header size
-	isIPv4 := dst.Address.Family().IsIPv4()
-	ipHdrSize := header.IPv6MinimumSize
-	ipProtocol := header.IPv6ProtocolNumber
-	if isIPv4 {
-		ipHdrSize = header.IPv4MinimumSize
-		ipProtocol = header.IPv4ProtocolNumber
-	}
-
-	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
-		Payload:            buffer.MakeWithData(payload),
-	})
-	defer pkt.DecRef()
-
-	// Build UDP header
-	udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-	udpHdr.Encode(&header.UDPFields{
-		SrcPort: uint16(src.Port),
-		DstPort: uint16(dst.Port),
-		Length:  uint16(udpLen),
-	})
-
-	// Calculate and set UDP checksum
-	xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
-	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
-
-	// Build IP header
-	if isIPv4 {
-		ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
-		ipHdr.Encode(&header.IPv4Fields{
-			TotalLength: uint16(header.IPv4MinimumSize + udpLen),
-			TTL:         64,
-			Protocol:    uint8(header.UDPProtocolNumber),
-			SrcAddr:     srcIP,
-			DstAddr:     dstIP,
-		})
-		ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
-	} else {
-		ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
-		ipHdr.Encode(&header.IPv6Fields{
-			PayloadLength:     uint16(udpLen),
-			TransportProtocol: header.UDPProtocolNumber,
-			HopLimit:          64,
-			SrcAddr:           srcIP,
-			DstAddr:           dstIP,
-		})
-	}
-
-	// dispatch the packet
-	err := t.stack.WriteRawPacket(defaultNIC, ipProtocol, buffer.MakeWithView(pkt.ToView()))
-	if err != nil {
-		return errors.New("failed to write raw udp packet back to stack", err)
-	}
-
-	return nil
-}
-
 // Close is called by Handler to shut down the stack
 func (t *stackGVisor) Close() error {
 	if t.stack == nil {
```
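The removed writeRawUDPPacket sets the UDP checksum as the one's complement of a sum over a pseudo-header plus payload, via udpHdr.SetChecksum(^udpHdr.CalculateChecksum(...)). A minimal standalone sketch of the underlying RFC 1071 checksum fold, independent of gVisor:

```go
package main

import "fmt"

// onesComplementSum folds big-endian 16-bit words into a 16-bit sum with
// end-around carry; this is the core of the UDP and IPv4 checksums.
func onesComplementSum(b []byte) uint16 {
	var sum uint32
	for len(b) >= 2 {
		sum += uint32(b[0])<<8 | uint32(b[1])
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8 // odd trailing byte is padded with zero
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16) // fold carries back in
	}
	return uint16(sum)
}

func main() {
	// The header stores the one's complement of the sum, matching
	// udpHdr.SetChecksum(^udpHdr.CalculateChecksum(...)) above.
	data := []byte{0x45, 0x00, 0x00, 0x1c}
	fmt.Printf("%#04x\n", ^onesComplementSum(data))
}
```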
```diff
@@ -1,134 +0,0 @@
-package tun
-
-import (
-	"io"
-	"sync"
-
-	"github.com/xtls/xray-core/common/buf"
-	"github.com/xtls/xray-core/common/net"
-)
-
-// sub-handler specifically for udp connections under main handler
-type udpConnectionHandler struct {
-	sync.Mutex
-
-	udpConns map[net.Destination]*udpConn
-
-	handleConnection func(conn net.Conn, dest net.Destination)
-	writePacket      func(data []byte, src net.Destination, dst net.Destination) error
-}
-
-func newUdpConnectionHandler(handleConnection func(conn net.Conn, dest net.Destination), writePacket func(data []byte, src net.Destination, dst net.Destination) error) *udpConnectionHandler {
-	handler := &udpConnectionHandler{
-		udpConns:         make(map[net.Destination]*udpConn),
-		handleConnection: handleConnection,
-		writePacket:      writePacket,
-	}
-
-	return handler
-}
-
-// HandlePacket handles UDP packets coming from tun, to forward to the dispatcher
-// this custom handler support FullCone NAT of returning packets, binding connection only by the source addr:port
-func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool {
-	u.Lock()
-	conn, found := u.udpConns[src]
-	if !found {
-		egress := make(chan []byte, 16)
-		conn = &udpConn{handler: u, egress: egress, src: src, dst: dst}
-		u.udpConns[src] = conn
-
-		go u.handleConnection(conn, dst)
-	}
-	u.Unlock()
-
-	// send packet data to the egress channel, if it has buffer, or discard
-	select {
-	case conn.egress <- data:
-	default:
-	}
-
-	return true
-}
-
-func (u *udpConnectionHandler) connectionFinished(src net.Destination) {
-	u.Lock()
-	conn, found := u.udpConns[src]
-	if found {
-		delete(u.udpConns, src)
-		close(conn.egress)
-	}
-	u.Unlock()
-}
-
-// udp connection abstraction
-type udpConn struct {
-	net.Conn
-	buf.Writer
-
-	handler *udpConnectionHandler
-
-	egress chan []byte
-	src    net.Destination
-	dst    net.Destination
-}
-
-// Read packets from the connection
-func (c *udpConn) Read(p []byte) (int, error) {
-	data, ok := <-c.egress
-	if !ok {
-		return 0, io.EOF
-	}
-
-	n := copy(p, data)
-	return n, nil
-}
-
-// Write returning packets back
-func (c *udpConn) Write(p []byte) (int, error) {
-	// sending packets back mean sending payload with source/destination reversed
-	err := c.handler.writePacket(p, c.dst, c.src)
-	if err != nil {
-		return 0, nil
-	}
-
-	return len(p), nil
-}
-
-func (c *udpConn) Close() error {
-	c.handler.connectionFinished(c.src)
-
-	return nil
-}
-
-func (c *udpConn) LocalAddr() net.Addr {
-	return &net.UDPAddr{IP: c.dst.Address.IP(), Port: int(c.dst.Port.Value())}
-}
-
-func (c *udpConn) RemoteAddr() net.Addr {
-	return &net.UDPAddr{IP: c.src.Address.IP(), Port: int(c.src.Port.Value())}
-}
-
-// Write returning packets back
-func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	for _, b := range mb {
-		dst := c.dst
-		if b.UDP != nil {
-			dst = *b.UDP
-		}
-
-		// validate address family matches between buffer packet and the connection
-		if dst.Address.Family() != c.dst.Address.Family() {
-			continue
-		}
-
-		// sending packets back mean sending payload with source/destination reversed
-		err := c.handler.writePacket(b.Bytes(), dst, c.src)
-		if err != nil {
-			// udp doesn't guarantee delivery, so in any failure we just continue to the next packet
-			continue
-		}
-	}
-
-	return nil
-}
```
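The deleted file keys UDP connections by source address only, so a reply from any remote peer can be routed back through the same mapping; that is what makes this NAT full-cone rather than symmetric. A compact standalone sketch of that keying idea (simplified types, locking omitted; the real code maps net.Destination to *udpConn):

```go
package main

import "fmt"

type addr = string // "ip:port"; stands in for net.Destination

type natTable struct {
	conns map[addr]chan []byte
}

// handlePacket binds a connection by source address only, as in
// udpConnectionHandler.HandlePacket above (mutex omitted for brevity).
func (t *natTable) handlePacket(src, dst addr, payload []byte) {
	ch, ok := t.conns[src]
	if !ok {
		ch = make(chan []byte, 16)
		t.conns[src] = ch // keyed by src alone: full-cone, not per (src, dst) pair
		fmt.Println("new mapping for", src, "first dst", dst)
	}
	select {
	case ch <- payload: // enqueue while there is buffer space
	default: // otherwise drop, matching UDP's best-effort semantics
	}
}

func main() {
	t := &natTable{conns: map[addr]chan []byte{}}
	t.handlePacket("10.0.0.2:5000", "1.1.1.1:53", []byte("a"))
	t.handlePacket("10.0.0.2:5000", "8.8.8.8:53", []byte("b")) // reuses the mapping
	fmt.Println("mappings:", len(t.conns))                     // prints: mappings: 1
}
```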
```diff
@@ -124,6 +124,26 @@ type netBindClient struct {
 	ctx      context.Context
 	dialer   internet.Dialer
 	reserved []byte
+
+	// Track all peer connections for unified reading
+	connMutex sync.RWMutex
+	conns     map[*netEndpoint]net.Conn
+	dataChan  chan *receivedData
+	closeChan chan struct{}
+	closeOnce sync.Once
 }
+
+const (
+	// Buffer size for dataChan - allows some buffering of received packets
+	// while dispatcher matches them with read requests
+	dataChannelBufferSize = 100
+)
+
+type receivedData struct {
+	data     []byte
+	n        int
+	endpoint *netEndpoint
+	err      error
+}
 
 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
@@ -133,34 +153,114 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 	}
 	endpoint.conn = c
 
-	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
+	// Initialize channels on first connection
+	bind.connMutex.Lock()
+	if bind.conns == nil {
+		bind.conns = make(map[*netEndpoint]net.Conn)
+		bind.dataChan = make(chan *receivedData, dataChannelBufferSize)
+		bind.closeChan = make(chan struct{})
+
+		// Start unified reader dispatcher
+		go bind.unifiedReader()
+	}
+	bind.conns[endpoint] = c
+	bind.connMutex.Unlock()
+
+	// Start a reader goroutine for this specific connection
+	go func(conn net.Conn, endpoint *netEndpoint) {
+		const maxPacketSize = 1500
 		for {
-			v, ok := <-readQueue
-			if !ok {
+			select {
+			case <-bind.closeChan:
 				return
+			default:
 			}
-			i, err := c.Read(v.buff)
 
-			if i > 3 {
-				v.buff[1] = 0
-				v.buff[2] = 0
-				v.buff[3] = 0
-			}
+			buf := make([]byte, maxPacketSize)
+			n, err := conn.Read(buf)
 
-			v.bytes = i
-			v.endpoint = endpoint
-			v.err = err
-			v.waiter.Done()
+			// Send only the valid data portion to dispatcher
+			dataToSend := buf
+			if n > 0 && n < len(buf) {
+				dataToSend = buf[:n]
+			}
+
+			// Send received data to dispatcher
+			select {
+			case bind.dataChan <- &receivedData{
+				data:     dataToSend,
+				n:        n,
+				endpoint: endpoint,
+				err:      err,
+			}:
+			case <-bind.closeChan:
+				return
+			}
 
 			if err != nil {
+				bind.connMutex.Lock()
+				delete(bind.conns, endpoint)
 				endpoint.conn = nil
+				bind.connMutex.Unlock()
 				return
 			}
 		}
-	}(bind.readQueue, endpoint)
+	}(c, endpoint)
 
 	return nil
 }
 
+// unifiedReader dispatches received data to waiting read requests
+func (bind *netBindClient) unifiedReader() {
+	for {
+		select {
+		case data := <-bind.dataChan:
+			// Bounds check to prevent panic
+			if data.n > len(data.data) {
+				data.n = len(data.data)
+			}
+
+			// Wait for a read request with timeout to prevent blocking forever
+			select {
+			case v := <-bind.readQueue:
+				// Copy data to request buffer
+				n := copy(v.buff, data.data[:data.n])
+
+				// Clear reserved bytes if needed
+				if n > 3 {
+					v.buff[1] = 0
+					v.buff[2] = 0
+					v.buff[3] = 0
+				}
+
+				v.bytes = n
+				v.endpoint = data.endpoint
+				v.err = data.err
+				v.waiter.Done()
+			case <-bind.closeChan:
+				return
+			}
+		case <-bind.closeChan:
+			return
+		}
+	}
+}
+
+// Close implements conn.Bind.Close for netBindClient
+func (bind *netBindClient) Close() error {
+	// Use sync.Once to prevent double-close panic
+	bind.closeOnce.Do(func() {
+		bind.connMutex.Lock()
+		if bind.closeChan != nil {
+			close(bind.closeChan)
+		}
+		bind.connMutex.Unlock()
+	})
+
+	// Call parent Close
+	return bind.netBind.Close()
+}
+
 func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
 	var err error
```
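The added code is a fan-in: one reader goroutine per peer connection pushes into the shared dataChan, a single dispatcher pairs each datagram with a pending read request, and closeChan unblocks everything on shutdown. A minimal standalone sketch of the same pattern (illustrative names, stdlib only):

```go
package main

import "fmt"

type datagram struct {
	src  int
	data []byte
}

// readerLoop models one per-connection reader: it forwards everything it
// receives into the shared channel until closeChan is closed.
func readerLoop(src int, in <-chan []byte, shared chan<- datagram, closeChan <-chan struct{}) {
	for data := range in {
		select {
		case shared <- datagram{src: src, data: data}:
		case <-closeChan: // unblock on Close, like bind.closeChan above
			return
		}
	}
}

func main() {
	shared := make(chan datagram, 100) // like dataChan with dataChannelBufferSize
	closeChan := make(chan struct{})

	conns := []chan []byte{make(chan []byte, 1), make(chan []byte, 1)}
	for i, c := range conns {
		go readerLoop(i, c, shared, closeChan)
	}

	conns[0] <- []byte("from peer 0")
	conns[1] <- []byte("from peer 1")

	// The dispatcher side: drain whichever reader delivered first.
	for i := 0; i < 2; i++ {
		d := <-shared
		fmt.Printf("peer %d: %s\n", d.src, d.data)
	}
	close(closeChan)
}
```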
```diff
@@ -114,6 +114,12 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
 	}
 
 	// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
+	// Set workers to number of peers if not explicitly configured
+	// This allows concurrent packet reception from multiple peers
+	workers := int(h.conf.NumWorkers)
+	if workers <= 0 && len(h.conf.Peers) > 0 {
+		workers = len(h.conf.Peers)
+	}
 	h.bind = &netBindClient{
 		netBind: netBind{
 			dns: h.dns,
@@ -121,9 +127,9 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
 				IPv4Enable: h.hasIPv4,
 				IPv6Enable: h.hasIPv6,
 			},
-			workers: int(h.conf.NumWorkers),
+			workers: workers,
 		},
-		ctx:      ctx,
+		ctx:      core.ToBackgroundDetachedContext(ctx),
 		dialer:   dialer,
 		reserved: h.conf.Reserved,
 	}
```
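The first hunk defaults the worker count to the number of peers when NumWorkers is unset. A minimal restatement of that defaulting logic:

```go
package main

import "fmt"

// workerCount mirrors the defaulting added above: with no explicit
// NumWorkers, use one receive worker per configured peer.
func workerCount(configured, peers int) int {
	if configured <= 0 && peers > 0 {
		return peers
	}
	return configured
}

func main() {
	fmt.Println(workerCount(0, 3)) // 3: defaults to the peer count
	fmt.Println(workerCount(8, 3)) // 8: explicit setting wins
}
```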
```diff
@@ -173,7 +173,7 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
 	})
 	stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
-	udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
+	udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) {
 		go func(r *udp.ForwarderRequest) {
 			var (
 				wq waiter.Queue
@@ -195,8 +195,6 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
 
 			handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
 		}(r)
-
-		return true
 	})
 	stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
 }
```