diff --git a/proxy/tun/handler.go b/proxy/tun/handler.go index b6f9097e..85f8d1bd 100644 --- a/proxy/tun/handler.go +++ b/proxy/tun/handler.go @@ -28,14 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -type udpConn struct { - lastActive int64 - reader buf.Reader - writer buf.Writer - done *done.Instance - cancel context.CancelFunc -} - // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing type Handler struct { sync.Mutex @@ -44,7 +36,7 @@ type Handler struct { stack Stack policyManager policy.Manager dispatcher routing.Dispatcher - udpConns map[net.Destination]*udpConn + udpConns map[net.Destination]*struct{ lastActive int64; reader buf.Reader; writer buf.Writer; done *done.Instance; cancel context.CancelFunc } udpChecker *task.Periodic } @@ -66,111 +58,71 @@ func (t *Handler) cleanupUDP() error { if len(t.udpConns) == 0 { return errors.New("no connections") } - now := time.Now().Unix() for src, conn := range t.udpConns { - if now-atomic.LoadInt64(&conn.lastActive) > 300 { - conn.cancel() - common.Must(conn.done.Close()) - common.Must(common.Close(conn.writer)) - delete(t.udpConns, src) + if time.Now().Unix()-atomic.LoadInt64(&conn.lastActive) > 300 { + conn.cancel(); common.Must(conn.done.Close()); common.Must(common.Close(conn.writer)); delete(t.udpConns, src) } } return nil } func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) { - src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)) - dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)) - data := pkt.Data().AsRange().ToSlice() - if len(data) == 0 { - return - } - - t.Lock() - conn, found := t.udpConns[src] - if !found { - reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024)) - conn = &udpConn{reader: reader, writer: writer, done: done.New()} - t.udpConns[src] = conn - if t.udpChecker != nil && len(t.udpConns) == 1 { - common.Must(t.udpChecker.Start()) - } - t.Unlock() - - go func() { - ctx, cancel := context.WithCancel(t.ctx) - conn.cancel = cancel - defer func() { - cancel() - t.Lock() - delete(t.udpConns, src) - t.Unlock() - common.Must(conn.done.Close()) - common.Must(common.Close(conn.writer)) + src, dest := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)), net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)) + if data := pkt.Data().AsRange().ToSlice(); len(data) > 0 { + t.Lock() + conn, found := t.udpConns[src] + if !found { + reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024)) + conn = &struct{ lastActive int64; reader buf.Reader; writer buf.Writer; done *done.Instance; cancel context.CancelFunc }{reader: reader, writer: writer, done: done.New()} + t.udpConns[src] = conn + if t.udpChecker != nil && len(t.udpConns) == 1 { + common.Must(t.udpChecker.Start()) + } + t.Unlock() + go func() { + ctx, cancel := context.WithCancel(t.ctx) + conn.cancel = cancel + defer func() { + cancel() + t.Lock() + delete(t.udpConns, src) + t.Unlock() + common.Must(conn.done.Close()) + common.Must(common.Close(conn.writer)) + }() + t.dispatcher.DispatchLink(c.ContextWithID(session.ContextWithInbound(ctx, &session.Inbound{Name: "tun", Source: src, User: &protocol.MemoryUser{Level: t.config.UserLevel}}), session.NewID()), dest, &transport.Link{Reader: conn.reader, Writer: &udpWriter{stack: ipStack, src: dest, dest: src}}) }() - - ctx = c.ContextWithID(ctx, session.NewID()) - ctx = session.ContextWithInbound(ctx, &session.Inbound{ - Name: "tun", Source: src, - User: &protocol.MemoryUser{Level: t.config.UserLevel}, - }) - - t.dispatcher.DispatchLink(ctx, dest, &transport.Link{ - Reader: conn.reader, - Writer: &udpWriter{stack: ipStack, src: dest, dest: src}, - }) - }() - } else { - atomic.StoreInt64(&conn.lastActive, time.Now().Unix()) - t.Unlock() + } else { + atomic.StoreInt64(&conn.lastActive, time.Now().Unix()) + t.Unlock() + } + b := buf.New() + b.Write(data) + b.UDP = &dest + conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) } - - b := buf.New() - b.Write(data) - b.UDP = &dest - conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) } -type udpWriter struct { - stack *stack.Stack - src net.Destination - dest net.Destination -} +type udpWriter struct{ stack *stack.Stack; src, dest net.Destination } func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { for _, b := range mb { if b.UDP != nil { w.src = *b.UDP } - netProto := header.IPv4ProtocolNumber if !w.src.Address.Family().IsIPv4() { netProto = header.IPv6ProtocolNumber } - - route, err := w.stack.FindRoute(defaultNIC, tcpip.AddrFromSlice(w.src.Address.IP()), tcpip.AddrFromSlice(w.dest.Address.IP()), netProto, false) - if err != nil { - b.Release() - continue + if route, err := w.stack.FindRoute(defaultNIC, tcpip.AddrFromSlice(w.src.Address.IP()), tcpip.AddrFromSlice(w.dest.Address.IP()), netProto, false); err == nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: header.UDPMinimumSize, Payload: buffer.MakeWithData(b.Bytes())}) + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udp.Encode(&header.UDPFields{SrcPort: uint16(w.src.Port), DstPort: uint16(w.dest.Port), Length: uint16(pkt.Size())}) + udp.SetChecksum(^udp.CalculateChecksum(checksum.Checksum(b.Bytes(), route.PseudoHeaderChecksum(header.UDPProtocolNumber, uint16(pkt.Size()))))) + route.WritePacket(stack.NetworkHeaderParams{Protocol: header.UDPProtocolNumber, TTL: 64}, pkt) + pkt.DecRef() + route.Release() } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize, - Payload: buffer.MakeWithData(b.Bytes()), - }) - - udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - udp.Encode(&header.UDPFields{ - SrcPort: uint16(w.src.Port), - DstPort: uint16(w.dest.Port), - Length: uint16(pkt.Size()), - }) - xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, uint16(pkt.Size())) - udp.SetChecksum(^udp.CalculateChecksum(checksum.Checksum(b.Bytes(), xsum))) - - route.WritePacket(stack.NetworkHeaderParams{Protocol: header.UDPProtocolNumber, TTL: 64}, pkt) - pkt.DecRef() - route.Release() b.Release() } return nil @@ -178,53 +130,31 @@ func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { // Init the Handler instance with necessary parameters func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error { - var err error - - t.ctx = core.ToBackgroundDetachedContext(ctx) - t.policyManager = pm - t.dispatcher = dispatcher - t.udpConns = make(map[net.Destination]*udpConn) + t.ctx, t.policyManager, t.dispatcher = core.ToBackgroundDetachedContext(ctx), pm, dispatcher + t.udpConns = make(map[net.Destination]*struct{ lastActive int64; reader buf.Reader; writer buf.Writer; done *done.Instance; cancel context.CancelFunc }) t.udpChecker = &task.Periodic{Interval: time.Minute, Execute: t.cleanupUDP} - - tunName := t.config.Name - tunOptions := TunOptions{ - Name: tunName, - MTU: t.config.MTU, - } - tunInterface, err := NewTun(tunOptions) + tunInterface, err := NewTun(TunOptions{Name: t.config.Name, MTU: t.config.MTU}) if err != nil { return err } - - errors.LogInfo(t.ctx, tunName, " created") - - tunStackOptions := StackOptions{ - Tun: tunInterface, - IdleTimeout: pm.ForLevel(t.config.UserLevel).Timeouts.ConnectionIdle, - } - tunStack, err := NewStack(t.ctx, tunStackOptions, t) + errors.LogInfo(t.ctx, t.config.Name, " created") + tunStack, err := NewStack(t.ctx, StackOptions{Tun: tunInterface, IdleTimeout: pm.ForLevel(t.config.UserLevel).Timeouts.ConnectionIdle}, t) if err != nil { _ = tunInterface.Close() return err } - - err = tunStack.Start() - if err != nil { + if err = tunStack.Start(); err != nil { _ = tunStack.Close() _ = tunInterface.Close() return err } - - err = tunInterface.Start() - if err != nil { + if err = tunInterface.Start(); err != nil { _ = tunStack.Close() _ = tunInterface.Close() return err } - t.stack = tunStack - - errors.LogInfo(t.ctx, tunName, " up") + errors.LogInfo(t.ctx, t.config.Name, " up") return nil }