From 551c17d9914e7b343eb39325626a51ec9fc3d1d4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 9 Jan 2026 13:11:15 +0000 Subject: [PATCH] Further simplify code structure - Consolidate udpConn methods into single-line implementations - Remove setInactive method, use direct field access - Simplify writeUDPPacket variable declarations - Extract handleUDPConn as separate method for clarity - Reduce cleanupUDPConns redundancy - Inline struct initializations in HandleUDPPacket Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com> --- proxy/tun/handler.go | 162 +++++++++++++------------------------------ 1 file changed, 50 insertions(+), 112 deletions(-) diff --git a/proxy/tun/handler.go b/proxy/tun/handler.go index bd06919c..136f8016 100644 --- a/proxy/tun/handler.go +++ b/proxy/tun/handler.go @@ -47,15 +47,10 @@ type udpConn struct { cancel context.CancelFunc } -func (c *udpConn) setInactive() { - c.inactive = true -} - func (c *udpConn) updateActivity() { atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix()) } -// ReadMultiBuffer implements buf.Reader func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { mb, err := c.reader.ReadMultiBuffer() if err != nil { @@ -65,14 +60,7 @@ func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { return mb, nil } -func (c *udpConn) Read(buf []byte) (int, error) { - return 0, errors.New("Read not supported, use ReadMultiBuffer instead") -} - -// Write implements io.Writer func (c *udpConn) Write(data []byte) (int, error) { - // Extract destination from the first buffer if available - // For now, write with empty destination (will be filled by output function) n, err := c.output(data, net.Destination{}) if err == nil { c.updateActivity() @@ -89,25 +77,12 @@ func (c *udpConn) Close() error { return nil } -func (c *udpConn) RemoteAddr() net.Addr { - return c.remote -} - -func (c *udpConn) LocalAddr() net.Addr { - return c.local -} - -func (*udpConn) SetDeadline(time.Time) error { - return nil -} - -func (*udpConn) SetReadDeadline(time.Time) error { - return nil -} - -func (*udpConn) SetWriteDeadline(time.Time) error { - return nil -} +func (c *udpConn) RemoteAddr() net.Addr { return c.remote } +func (c *udpConn) LocalAddr() net.Addr { return c.local } +func (c *udpConn) Read([]byte) (int, error) { return 0, errors.New("not supported") } +func (*udpConn) SetDeadline(time.Time) error { return nil } +func (*udpConn) SetReadDeadline(time.Time) error { return nil } +func (*udpConn) SetWriteDeadline(time.Time) error { return nil } // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing type Handler struct { @@ -189,21 +164,19 @@ func (t *Handler) removeUDPConn(id udpConnID) { // cleanupUDPConns removes inactive UDP connections func (t *Handler) cleanupUDPConns() error { - nowSec := time.Now().Unix() t.Lock() defer t.Unlock() if len(t.udpConns) == 0 { - return errors.New("UDP connection cleanup stopped: no active connections remaining") + return errors.New("no active connections") } + nowSec := time.Now().Unix() for id, conn := range t.udpConns { - if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 { // 5 minutes - if !conn.inactive { - conn.setInactive() - conn.Close() - delete(t.udpConns, id) - } + if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 && !conn.inactive { + conn.inactive = true + conn.Close() + delete(t.udpConns, id) } } @@ -212,48 +185,34 @@ func (t *Handler) cleanupUDPConns() error { // writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) { - // Create a route from dest (our local) to source (remote) - var netProto tcpip.NetworkProtocolNumber - if dest.Address.Family().IsIPv4() { - netProto = header.IPv4ProtocolNumber - } else { + netProto := header.IPv4ProtocolNumber + if !dest.Address.Family().IsIPv4() { netProto = header.IPv6ProtocolNumber } - route, err := ipStack.FindRoute( - defaultNIC, - tcpip.AddrFromSlice(dest.Address.IP()), - tcpip.AddrFromSlice(source.Address.IP()), - netProto, - false, - ) + route, err := ipStack.FindRoute(defaultNIC, tcpip.AddrFromSlice(dest.Address.IP()), tcpip.AddrFromSlice(source.Address.IP()), netProto, false) if err != nil { return 0, errors.New("failed to find route: " + err.String()) } defer route.Release() - // Create packet buffer with UDP payload pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.UDPMinimumSize, Payload: buffer.MakeWithData(data), }) defer pkt.DecRef() - // Build UDP header - udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) length := uint16(pkt.Size()) + udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) udpHeader.Encode(&header.UDPFields{ SrcPort: uint16(dest.Port), DstPort: uint16(source.Port), Length: length, }) - // Calculate checksum xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length) - xsum = checksum.Checksum(data, xsum) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(checksum.Checksum(data, xsum))) - // Write packet through route if err := route.WritePacket(stack.NetworkHeaderParams{ Protocol: header.UDPProtocolNumber, TTL: 64, @@ -267,80 +226,59 @@ func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source // HandleUDPPacket processes a raw UDP packet from gVisor func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) { - // Extract packet information - source := net.UDPDestination( - net.IPAddress(id.RemoteAddress.AsSlice()), - net.Port(id.RemotePort), - ) - dest := net.UDPDestination( - net.IPAddress(id.LocalAddress.AsSlice()), - net.Port(id.LocalPort), - ) + source := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)) + dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)) - // Extract UDP payload data := pkt.Data().AsRange().ToSlice() if len(data) == 0 { return } - // Get or create connection for this source conn, existing := t.getUDPConn(source, dest, ipStack) - // Create buffer and set UDP destination b := buf.New() b.Write(data) b.UDP = &dest - - // Write to connection pipe conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) if !existing { - // Start checker for cleanup (only once) t.Lock() if t.udpChecker != nil && len(t.udpConns) == 1 { common.Must(t.udpChecker.Start()) } t.Unlock() - // Start handling this connection - go func() { - connID := udpConnID{ - src: source, - } - if !t.cone { - connID.dest = dest - } - - ctx, cancel := context.WithCancel(t.ctx) - conn.cancel = cancel - sid := session.NewID() - ctx = c.ContextWithID(ctx, sid) - - inbound := session.Inbound{} - inbound.Name = "tun" - inbound.Source = source - inbound.User = &protocol.MemoryUser{ - Level: t.config.UserLevel, - } - - ctx = session.ContextWithInbound(ctx, &inbound) - ctx = session.SubContextFromMuxInbound(ctx) - - link := &transport.Link{ - Reader: conn.reader, - Writer: buf.NewWriter(conn), - } - - if err := t.dispatcher.DispatchLink(ctx, dest, link); err != nil { - errors.LogError(ctx, errors.New("UDP connection ended").Base(err)) - } - - conn.Close() - if !conn.inactive { - conn.setInactive() - t.removeUDPConn(connID) - } - }() + go t.handleUDPConn(conn, source, dest) + } +} + +func (t *Handler) handleUDPConn(conn *udpConn, source, dest net.Destination) { + connID := udpConnID{src: source} + if !t.cone { + connID.dest = dest + } + + ctx, cancel := context.WithCancel(t.ctx) + conn.cancel = cancel + ctx = c.ContextWithID(ctx, session.NewID()) + ctx = session.ContextWithInbound(ctx, &session.Inbound{ + Name: "tun", + Source: source, + User: &protocol.MemoryUser{Level: t.config.UserLevel}, + }) + ctx = session.SubContextFromMuxInbound(ctx) + + if err := t.dispatcher.DispatchLink(ctx, dest, &transport.Link{ + Reader: conn.reader, + Writer: buf.NewWriter(conn), + }); err != nil { + errors.LogError(ctx, errors.New("UDP connection ended").Base(err)) + } + + conn.Close() + if !conn.inactive { + conn.inactive = true + t.removeUDPConn(connID) } }