From fad82ee65c969aafc8214756a7447dce750ab847 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 30 Dec 2024 13:34:51 +0100 Subject: [PATCH] Add stop methods and improve udp implementation --- client/firewall/uspfilter/allow_netbird.go | 4 + .../uspfilter/allow_netbird_windows.go | 4 + .../firewall/uspfilter/forwarder/forwarder.go | 21 ++ client/firewall/uspfilter/forwarder/tcp.go | 42 ++-- client/firewall/uspfilter/forwarder/udp.go | 217 ++++++++++++------ 5 files changed, 205 insertions(+), 83 deletions(-) diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cc0792255..297095090 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -30,6 +30,10 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } + if m.forwarder != nil { + m.forwarder.Stop() + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 0d55d6268..42bf0896e 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -42,6 +42,10 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } + if m.forwarder != nil { + m.forwarder.Stop() + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 4554ebb20..815c7da09 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -1,6 +1,7 @@ package forwarder import ( + "context" "fmt" log "github.com/sirupsen/logrus" @@ -24,9 +25,12 @@ type Forwarder struct { stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc } func New(iface common.IFaceMapper) (*Forwarder, error) { + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -85,10 +89,13 @@ func New(iface common.IFaceMapper) (*Forwarder, error) { }, }) + ctx, cancel := context.WithCancel(context.Background()) f := &Forwarder{ stack: s, endpoint: endpoint, udpForwarder: newUDPForwarder(), + ctx: ctx, + cancel: cancel, } // Set up TCP forwarder @@ -118,3 +125,17 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error { } return nil } + +// Stop gracefully shuts down the forwarder +func (f *Forwarder) Stop() error { + f.cancel() + + if f.udpForwarder != nil { + f.udpForwarder.Stop() + } + + f.stack.Close() + f.stack.Wait() + + return nil +} diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 4f406dea5..90967b6f5 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -1,10 +1,10 @@ package forwarder import ( + "context" "fmt" "io" "net" - "sync" log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -20,9 +20,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { dstPort := id.LocalPort dialAddr := fmt.Sprintf("%s:%d", dstAddr.String(), dstPort) - // Dial the destination first - dialer := net.Dialer{} - outConn, err := dialer.Dial("tcp", dialAddr) + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) return @@ -40,8 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { return } - // Now that we've successfully connected to the destination, - // we can complete the incoming connection + // Complete the handshake r.Complete(false) inConn := gonet.NewTCPConn(&wq, ep) @@ -59,24 +56,35 @@ func (f *Forwarder) proxyTCP(inConn *gonet.TCPConn, outConn net.Conn) { } }() - var wg sync.WaitGroup - wg.Add(2) + // Create context for managing the proxy goroutines + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + errChan := make(chan error, 2) go func() { - defer wg.Done() - _, err := io.Copy(outConn, inConn) - if err != nil { - log.Errorf("proxyTCP: copy error: %v", err) + n, err := io.Copy(outConn, inConn) + if err != nil && !isClosedError(err) { + log.Errorf("proxyTCP: inbound->outbound copy error after %d bytes: %v", n, err) } + errChan <- err }() go func() { - defer wg.Done() - _, err := io.Copy(inConn, outConn) - if err != nil { - log.Errorf("proxyTCP: copy error: %v", err) + n, err := io.Copy(inConn, outConn) + if err != nil && !isClosedError(err) { + log.Errorf("proxyTCP: outbound->inbound copy error after %d bytes: %v", n, err) } + errChan <- err }() - wg.Wait() + select { + case <-ctx.Done(): + return + case err := <-errChan: + if err != nil && !isClosedError(err) { + log.Errorf("proxyTCP: copy error: %v", err) + } + return + } } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index 836d570cb..7d201024e 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -1,6 +1,8 @@ package forwarder import ( + "context" + "errors" "fmt" "net" "sync" @@ -8,49 +10,94 @@ import ( log "github.com/sirupsen/logrus" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) const ( - udpTimeout = 60 * time.Second + udpTimeout = 60 * time.Second + maxPacketSize = 65535 ) type udpPacketConn struct { conn *gonet.UDPConn outConn net.Conn lastTime time.Time + cancel context.CancelFunc } type udpForwarder struct { sync.RWMutex - conns map[string]*udpPacketConn + conns map[stack.TransportEndpointID]*udpPacketConn + bufPool sync.Pool + ctx context.Context + cancel context.CancelFunc } func newUDPForwarder() *udpForwarder { + ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ - conns: make(map[string]*udpPacketConn), + conns: make(map[stack.TransportEndpointID]*udpPacketConn), + ctx: ctx, + cancel: cancel, + bufPool: sync.Pool{ + New: func() any { + b := make([]byte, maxPacketSize) + return &b + }, + }, } go f.cleanup() return f } +// Stop stops the UDP forwarder and all active connections +func (f *udpForwarder) Stop() { + f.cancel() + + f.Lock() + defer f.Unlock() + + for id, conn := range f.conns { + conn.cancel() + if err := conn.conn.Close(); err != nil { + log.Errorf("forwarder: UDP conn close error for %v: %v", id, err) + } + if err := conn.outConn.Close(); err != nil { + log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) + } + delete(f.conns, id) + } +} + // cleanup periodically removes idle UDP connections func (f *udpForwarder) cleanup() { ticker := time.NewTicker(time.Minute) defer ticker.Stop() - for range ticker.C { - f.Lock() - now := time.Now() - for addr, conn := range f.conns { - if now.Sub(conn.lastTime) > udpTimeout { - conn.conn.Close() - conn.outConn.Close() - delete(f.conns, addr) + for { + select { + case <-f.ctx.Done(): + return + case <-ticker.C: + f.Lock() + now := time.Now() + for id, conn := range f.conns { + if now.Sub(conn.lastTime) > udpTimeout { + conn.cancel() + if err := conn.conn.Close(); err != nil { + log.Errorf("forwarder: UDP conn close error for %v: %v", id, err) + } + if err := conn.outConn.Close(); err != nil { + log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) + } + delete(f.conns, id) + log.Debugf("forwarder: cleaned up idle UDP connection %v", id) + } } + f.Unlock() } - f.Unlock() } } @@ -59,12 +106,17 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { id := r.ID() dstAddr := fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort) + if f.ctx.Err() != nil { + log.Debug("forwarder: context done, dropping UDP packet") + return + } + // Create wait queue for blocking syscalls wq := waiter.Queue{} ep, err := r.CreateEndpoint(&wq) if err != nil { - log.Errorf("Create UDP endpoint error: %v", err) + log.Errorf("forwarder: failed to create UDP endpoint: %v", err) return } @@ -72,82 +124,115 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { // Try to get existing connection or create a new one f.udpForwarder.Lock() - pConn, exists := f.udpForwarder.conns[dstAddr] + defer f.udpForwarder.Unlock() + + pConn, exists := f.udpForwarder.conns[id] if !exists { - outConn, err := net.Dial("udp", dstAddr) + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { - f.udpForwarder.Unlock() if err := inConn.Close(); err != nil { - log.Errorf("forwader: UDP inConn close error: %v", err) + log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err) } - log.Errorf("forwarder> UDP dial error: %v", err) + log.Errorf("forwarder: UDP dial error for %v: %v", id, err) return } + connCtx, connCancel := context.WithCancel(f.ctx) pConn = &udpPacketConn{ conn: inConn, outConn: outConn, lastTime: time.Now(), + cancel: connCancel, } - f.udpForwarder.conns[dstAddr] = pConn + f.udpForwarder.conns[id] = pConn - go f.proxyUDP(pConn, dstAddr) + go f.proxyUDP(connCtx, pConn, id) } - f.udpForwarder.Unlock() } -func (f *Forwarder) proxyUDP(pConn *udpPacketConn, dstAddr string) { +func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID) { defer func() { + pConn.cancel() if err := pConn.conn.Close(); err != nil { - log.Errorf("forwarder: inConn close error: %v", err) + log.Errorf("forwarder: UDP inConn close error for %v: %v", id, err) } if err := pConn.outConn.Close(); err != nil { - log.Errorf("forwarder: outConn close error: %v", err) - } - }() - - var wg sync.WaitGroup - wg.Add(2) - - // Handle outbound to inbound traffic - go func() { - defer wg.Done() - f.copyUDP(pConn.conn, pConn.outConn, dstAddr, "outbound->inbound") - }() - - // Handle inbound to outbound traffic - go func() { - defer wg.Done() - f.copyUDP(pConn.outConn, pConn.conn, dstAddr, "inbound->outbound") - }() - - wg.Wait() - - // Clean up the connection from the map - f.udpForwarder.Lock() - delete(f.udpForwarder.conns, dstAddr) - f.udpForwarder.Unlock() -} - -func (f *Forwarder) copyUDP(dst net.Conn, src net.Conn, dstAddr, direction string) { - buffer := make([]byte, 65535) - for { - n, err := src.Read(buffer) - if err != nil { - log.Errorf("UDP %s read error: %v", direction, err) - return - } - - _, err = dst.Write(buffer[:n]) - if err != nil { - log.Errorf("UDP %s write error: %v", direction, err) - continue + log.Errorf("forwarder: UDP outConn close error for %v: %v", id, err) } f.udpForwarder.Lock() - if conn, ok := f.udpForwarder.conns[dstAddr]; ok { - conn.lastTime = time.Now() - } + delete(f.udpForwarder.conns, id) f.udpForwarder.Unlock() + }() + + errChan := make(chan error, 2) + + go func() { + errChan <- f.copyUDP(ctx, pConn.conn, pConn.outConn, id, "outbound->inbound") + }() + + go func() { + errChan <- f.copyUDP(ctx, pConn.outConn, pConn.conn, id, "inbound->outbound") + }() + + select { + case <-ctx.Done(): + return + case err := <-errChan: + if err != nil && !isClosedError(err) { + log.Errorf("forwader: UDP proxy error for %v: %v", id, err) + } + return } } + +func (f *Forwarder) copyUDP(ctx context.Context, dst net.Conn, src net.Conn, id stack.TransportEndpointID, direction string) error { + bufp := f.udpForwarder.bufPool.Get().(*[]byte) + defer f.udpForwarder.bufPool.Put(bufp) + buffer := *bufp + + if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + n, err := src.Read(buffer) + if err != nil { + if isTimeout(err) { + continue + } + return fmt.Errorf("read from %s: %w", direction, err) + } + + _, err = dst.Write(buffer[:n]) + if err != nil { + return fmt.Errorf("write to %s: %w", direction, err) + } + + f.udpForwarder.Lock() + if conn, ok := f.udpForwarder.conns[id]; ok { + conn.lastTime = time.Now() + } + f.udpForwarder.Unlock() + } + } +} + +func isClosedError(err error) bool { + return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) +} + +func isTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +}