From 8e2d358564cdbd54387a15299e73e8881d45b6c8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 9 Jan 2026 12:48:20 +0000 Subject: [PATCH] Simplify UDP packet writing using Route.WritePacket Replace manual IP header construction with gVisor's Route API: - Use Stack.FindRoute() to create proper route - Use Route.WritePacket() with NetworkHeaderParams - Let gVisor handle IP header construction - Simpler and more maintainable code Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com> --- proxy/tun/handler.go | 114 +++++++++++++++---------------------------- 1 file changed, 38 insertions(+), 76 deletions(-) diff --git a/proxy/tun/handler.go b/proxy/tun/handler.go index 14c5d608..bd06919c 100644 --- a/proxy/tun/handler.go +++ b/proxy/tun/handler.go @@ -212,91 +212,53 @@ func (t *Handler) cleanupUDPConns() error { // writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) { - // Build UDP+IP packet with proper headers using gVisor's header builders - - // Determine IP version - var ipHdrLen, udpHdrLen int - isIPv4 := dest.Address.Family().IsIPv4() - - if isIPv4 { - ipHdrLen = header.IPv4MinimumSize + // Create a route from dest (our local) to source (remote) + var netProto tcpip.NetworkProtocolNumber + if dest.Address.Family().IsIPv4() { + netProto = header.IPv4ProtocolNumber } else { - ipHdrLen = header.IPv6MinimumSize + netProto = header.IPv6ProtocolNumber } - udpHdrLen = header.UDPMinimumSize - totalLen := ipHdrLen + udpHdrLen + len(data) - packet := make([]byte, totalLen) + route, err := ipStack.FindRoute( + defaultNIC, + tcpip.AddrFromSlice(dest.Address.IP()), + tcpip.AddrFromSlice(source.Address.IP()), + netProto, + false, + ) + if err != nil { + return 0, errors.New("failed to find route: " + err.String()) + } + defer route.Release() + + // Create packet buffer with UDP payload + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize, + Payload: buffer.MakeWithData(data), + }) + defer pkt.DecRef() // Build UDP header - udpHeader := header.UDP(packet[ipHdrLen:]) + udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + length := uint16(pkt.Size()) udpHeader.Encode(&header.UDPFields{ - SrcPort: uint16(dest.Port), // Source is the original destination - DstPort: uint16(source.Port), // Destination is the original source - Length: uint16(udpHdrLen + len(data)), + SrcPort: uint16(dest.Port), + DstPort: uint16(source.Port), + Length: length, }) - // Copy payload - copy(packet[ipHdrLen+udpHdrLen:], data) + // Calculate checksum + xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length) + xsum = checksum.Checksum(data, xsum) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - // Build IP header and calculate checksums - if isIPv4 { - ipv4Header := header.IPv4(packet) - ipv4Header.Encode(&header.IPv4Fields{ - TOS: 0, - TotalLength: uint16(totalLen), - ID: 0, - Flags: 0, - FragmentOffset: 0, - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: tcpip.AddrFromSlice(dest.Address.IP()), - DstAddr: tcpip.AddrFromSlice(source.Address.IP()), - }) - ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum()) - - // Calculate UDP checksum - xsum := header.PseudoHeaderChecksum( - header.UDPProtocolNumber, - tcpip.AddrFromSlice(dest.Address.IP()), - tcpip.AddrFromSlice(source.Address.IP()), - uint16(udpHdrLen+len(data)), - ) - xsum = checksum.Checksum(data, xsum) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } else { - ipv6Header := header.IPv6(packet) - ipv6Header.Encode(&header.IPv6Fields{ - TrafficClass: 0, - FlowLabel: 0, - PayloadLength: uint16(udpHdrLen + len(data)), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: 64, - SrcAddr: tcpip.AddrFromSlice(dest.Address.IP()), - DstAddr: tcpip.AddrFromSlice(source.Address.IP()), - }) - - // Calculate UDP checksum for IPv6 - xsum := header.PseudoHeaderChecksum( - header.UDPProtocolNumber, - tcpip.AddrFromSlice(dest.Address.IP()), - tcpip.AddrFromSlice(source.Address.IP()), - uint16(udpHdrLen+len(data)), - ) - xsum = checksum.Checksum(data, xsum) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } - - // Write packet to stack - var proto tcpip.NetworkProtocolNumber - if isIPv4 { - proto = header.IPv4ProtocolNumber - } else { - proto = header.IPv6ProtocolNumber - } - - buf := buffer.MakeWithData(packet) - if err := ipStack.WriteRawPacket(defaultNIC, proto, buf); err != nil { + // Write packet through route + if err := route.WritePacket(stack.NetworkHeaderParams{ + Protocol: header.UDPProtocolNumber, + TTL: 64, + TOS: 0, + }, pkt); err != nil { return 0, errors.New("failed to write packet: " + err.String()) }