From 6a97d44d5da1dd3a0219f69bf7fcaa7a050cbd4c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 30 Dec 2024 20:50:20 +0100 Subject: [PATCH] Improve udp implementation --- client/firewall/uspfilter/forwarder/tcp.go | 8 +- client/firewall/uspfilter/forwarder/udp.go | 149 +++++++++++++-------- 2 files changed, 96 insertions(+), 61 deletions(-) diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 25503cb6d..263fffb51 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -26,13 +26,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { return } - f.logger.Trace("forwarder: established TCP connection to %v", id) - // Create wait queue for blocking syscalls wq := waiter.Queue{} - ep, err2 := r.CreateEndpoint(&wq) - if err2 != nil { + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { if err := outConn.Close(); err != nil { f.logger.Error("forwarder: outConn close error: %v", err) } @@ -45,6 +43,8 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) + f.logger.Trace("forwarder: established TCP connection to %v", id) + go f.proxyTCP(inConn, outConn) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index bb43a8346..a6f3ab993 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -6,8 +6,10 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" + log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -24,7 +26,7 @@ const ( type udpPacketConn struct { conn *gonet.UDPConn outConn net.Conn - lastTime time.Time + lastSeen atomic.Int64 cancel context.CancelFunc } @@ -84,22 +86,37 @@ func (f *udpForwarder) cleanup() { case <-f.ctx.Done(): return case <-ticker.C: - f.Lock() - now := time.Now() + var idleConns []struct { + id stack.TransportEndpointID + conn *udpPacketConn + } + + f.RLock() for id, conn := range f.conns { - if now.Sub(conn.lastTime) > udpTimeout { - conn.cancel() - if err := conn.conn.Close(); err != nil { - f.logger.Error("forwarder: UDP conn close error for %v: %v", id, err) - } - if err := conn.outConn.Close(); err != nil { - f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err) - } - delete(f.conns, id) - f.logger.Trace("forwarder: cleaned up idle UDP connection %v", id) + if conn.getIdleDuration() > udpTimeout { + idleConns = append(idleConns, struct { + id stack.TransportEndpointID + conn *udpPacketConn + }{id, conn}) } } - f.Unlock() + f.RUnlock() + + for _, idle := range idleConns { + idle.conn.cancel() + if err := idle.conn.conn.Close(); err != nil { + f.logger.Error("forwarder: UDP conn close error for %v: %v", idle.id, err) + } + if err := idle.conn.outConn.Close(); err != nil { + f.logger.Error("forwarder: UDP outConn close error for %v: %v", idle.id, err) + } + + f.Lock() + delete(f.conns, idle.id) + f.Unlock() + + f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) + } } } } @@ -114,47 +131,60 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { return } + f.udpForwarder.RLock() + pConn, exists := f.udpForwarder.conns[id] + f.udpForwarder.RUnlock() + if exists { + f.logger.Trace("forwarder: existing UDP connection for %v", id) + return + } + + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) + if err != nil { + f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) + // TODO: Send ICMP error message + return + } + // Create wait queue for blocking syscalls wq := waiter.Queue{} - - ep, err := r.CreateEndpoint(&wq) - if err != nil { - f.logger.Error("forwarder: failed to create UDP endpoint: %v", err) + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + f.logger.Error("forwarder: failed to create UDP endpoint: %v", epErr) + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } return } inConn := gonet.NewUDPConn(f.stack, &wq, ep) + connCtx, connCancel := context.WithCancel(f.ctx) - // Try to get existing connection or create a new one - f.udpForwarder.Lock() - defer f.udpForwarder.Unlock() - - pConn, exists := f.udpForwarder.conns[id] - if !exists { - outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) - if err != nil { - if err := inConn.Close(); err != nil { - f.logger.Error("forwarder: UDP inConn close error for %v: %v", id, err) - } - f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) - - // TODO: Send ICMP error message - return - } - - f.logger.Trace("forwarder: established UDP connection to %v", id) - - connCtx, connCancel := context.WithCancel(f.ctx) - pConn = &udpPacketConn{ - conn: inConn, - outConn: outConn, - lastTime: time.Now(), - cancel: connCancel, - } - f.udpForwarder.conns[id] = pConn - - go f.proxyUDP(connCtx, pConn, id) + pConn = &udpPacketConn{ + conn: inConn, + outConn: outConn, + cancel: connCancel, } + pConn.updateLastSeen() + + f.udpForwarder.Lock() + // Double-check no connection was created while we were setting up + if _, exists := f.udpForwarder.conns[id]; exists { + f.udpForwarder.Unlock() + pConn.cancel() + if err := inConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) + } + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + return + } + f.udpForwarder.conns[id] = pConn + f.udpForwarder.Unlock() + + f.logger.Trace("forwarder: established UDP connection to %v", id) + go f.proxyUDP(connCtx, pConn, id) } func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) { @@ -175,11 +205,11 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack errChan := make(chan error, 2) go func() { - errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound") + errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") }() go func() { - errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound") + errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") }() select { @@ -193,9 +223,18 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack } } -func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error { - bufp := f.udpForwarder.bufPool.Get().(*[]byte) - defer f.udpForwarder.bufPool.Put(bufp) +func (c *udpPacketConn) updateLastSeen() { + c.lastSeen.Store(time.Now().UnixNano()) +} + +func (c *udpPacketConn) getIdleDuration() time.Duration { + lastSeen := time.Unix(0, c.lastSeen.Load()) + return time.Since(lastSeen) +} + +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { + bufp := bufPool.Get().(*[]byte) + defer bufPool.Put(bufp) buffer := *bufp if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { @@ -223,11 +262,7 @@ func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id return fmt.Errorf("write to %s: %w", direction, err) } - f.udpForwarder.Lock() - if conn, ok := f.udpForwarder.conns[id]; ok { - conn.lastTime = time.Now() - } - f.udpForwarder.Unlock() + c.updateLastSeen() } } }