diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index e805ea491..bf5320fe1 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -6,6 +6,7 @@ import ( "io" "net" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -32,6 +33,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { ep, epErr := r.CreateEndpoint(&wq) if epErr != nil { + f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr) if err := outConn.Close(); err != nil { f.logger.Error("forwarder: outConn close error: %v", err) } @@ -44,12 +46,12 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) - f.logger.Trace("forwarder: established TCP connection to %v", id) + f.logger.Trace("forwarder: established TCP connection %v", id) - go f.proxyTCP(id, inConn, outConn) + go f.proxyTCP(id, inConn, outConn, ep) } -func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn) { +func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { defer func() { if err := inConn.Close(); err != nil { f.logger.Error("forwarder: inConn close error: %v", err) @@ -57,6 +59,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn if err := outConn.Close(); err != nil { f.logger.Error("forwarder: outConn close error: %v", err) } + ep.Close() }() // Create context for managing the proxy goroutines diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 4491b0135..85094baad 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -27,6 +28,7 @@ type udpPacketConn struct { outConn net.Conn lastSeen atomic.Int64 cancel context.CancelFunc + ep tcpip.Endpoint } type udpForwarder struct { @@ -38,6 +40,11 @@ type udpForwarder struct { cancel context.CancelFunc } +type idleConn struct { + id stack.TransportEndpointID + conn *udpPacketConn +} + func newUDPForwarder(logger *nblog.Logger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ @@ -85,18 +92,12 @@ func (f *udpForwarder) cleanup() { case <-f.ctx.Done(): return case <-ticker.C: - var idleConns []struct { - id stack.TransportEndpointID - conn *udpPacketConn - } + var idleConns []idleConn f.RLock() for id, conn := range f.conns { if conn.getIdleDuration() > udpTimeout { - idleConns = append(idleConns, struct { - id stack.TransportEndpointID - conn *udpPacketConn - }{id, conn}) + idleConns = append(idleConns, idleConn{id, conn}) } } f.RUnlock() @@ -110,6 +111,8 @@ func (f *udpForwarder) cleanup() { f.logger.Error("forwarder: UDP outConn close error for %v: %v", idle.id, err) } + idle.conn.ep.Close() + f.Lock() delete(f.conns, idle.id) f.Unlock() @@ -163,6 +166,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { conn: inConn, outConn: outConn, cancel: connCancel, + ep: ep, } pConn.updateLastSeen() @@ -183,10 +187,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { f.udpForwarder.Unlock() f.logger.Trace("forwarder: established UDP connection to %v", id) - go f.proxyUDP(connCtx, pConn, id) + go f.proxyUDP(connCtx, pConn, id, ep) } -func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) { +func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { defer func() { pConn.cancel() if err := pConn.conn.Close(); err != nil { @@ -196,6 +200,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack f.logger.Error("forwarder: UDP outConn close error for %v: %v", id, err) } + ep.Close() + f.udpForwarder.Lock() delete(f.udpForwarder.conns, id) f.udpForwarder.Unlock()