diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 25cd9e87d..a8cb01565 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -1,6 +1,7 @@ package conntrack import ( + "context" "net" "sync" "time" @@ -39,8 +40,8 @@ type ICMPTracker struct { connections map[ICMPConnKey]*ICMPConnTrack timeout time.Duration cleanupTicker *time.Ticker + tickerCancel context.CancelFunc mutex sync.RWMutex - done chan struct{} ipPool *PreallocatedIPs } @@ -50,16 +51,18 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { timeout = DefaultICMPTimeout } + ctx, cancel := context.WithCancel(context.Background()) + tracker := &ICMPTracker{ logger: logger, connections: make(map[ICMPConnKey]*ICMPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), - done: make(chan struct{}), + tickerCancel: cancel, ipPool: NewPreallocatedIPs(), } - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) return tracker } @@ -119,12 +122,14 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq conn.Sequence == seq } -func (t *ICMPTracker) cleanupRoutine() { +func (t *ICMPTracker) cleanupRoutine(ctx context.Context) { + defer t.tickerCancel() + for { select { case <-t.cleanupTicker.C: t.cleanup() - case <-t.done: + case <-ctx.Done(): return } } @@ -146,8 +151,7 @@ func (t *ICMPTracker) cleanup() { // Close stops the cleanup routine and releases resources func (t *ICMPTracker) Close() { - t.cleanupTicker.Stop() - close(t.done) + t.tickerCancel() t.mutex.Lock() for _, conn := range t.connections { diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 7c12e8ad0..1b5cbae95 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -3,6 +3,7 @@ package conntrack // TODO: Send RST packets for invalid/timed-out connections import ( + "context" "net" "sync" "sync/atomic" @@ -85,23 +86,26 @@ type TCPTracker struct { connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker - done chan struct{} + tickerCancel context.CancelFunc timeout time.Duration ipPool *PreallocatedIPs } // NewTCPTracker creates a new TCP connection tracker func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { + + ctx, cancel := context.WithCancel(context.Background()) + tracker := &TCPTracker{ logger: logger, connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), - done: make(chan struct{}), + tickerCancel: cancel, timeout: timeout, ipPool: NewPreallocatedIPs(), } - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) return tracker } @@ -315,12 +319,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { return false } -func (t *TCPTracker) cleanupRoutine() { +func (t *TCPTracker) cleanupRoutine(ctx context.Context) { + defer t.cleanupTicker.Stop() + for { select { case <-t.cleanupTicker.C: t.cleanup() - case <-t.done: + case <-ctx.Done(): return } } @@ -355,8 +361,7 @@ func (t *TCPTracker) cleanup() { // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { - t.cleanupTicker.Stop() - close(t.done) + t.tickerCancel() // Clean up all remaining IPs t.mutex.Lock() diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e73465e31..073eb0fa2 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -1,6 +1,7 @@ package conntrack import ( + "context" "net" "sync" "time" @@ -26,8 +27,8 @@ type UDPTracker struct { connections map[ConnKey]*UDPConnTrack timeout time.Duration cleanupTicker *time.Ticker + tickerCancel context.CancelFunc mutex sync.RWMutex - done chan struct{} ipPool *PreallocatedIPs } @@ -37,16 +38,18 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { timeout = DefaultUDPTimeout } + ctx, cancel := context.WithCancel(context.Background()) + tracker := &UDPTracker{ logger: logger, connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), - done: make(chan struct{}), + tickerCancel: cancel, ipPool: NewPreallocatedIPs(), } - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) return tracker } @@ -103,12 +106,14 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, } // cleanupRoutine periodically removes stale connections -func (t *UDPTracker) cleanupRoutine() { +func (t *UDPTracker) cleanupRoutine(ctx context.Context) { + defer t.cleanupTicker.Stop() + for { select { case <-t.cleanupTicker.C: t.cleanup() - case <-t.done: + case <-ctx.Done(): return } } @@ -131,8 +136,7 @@ func (t *UDPTracker) cleanup() { // Close stops the cleanup routine and releases resources func (t *UDPTracker) Close() { - t.cleanupTicker.Stop() - close(t.done) + t.tickerCancel() t.mutex.Lock() for _, conn := range t.connections { diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index fa83ee356..40e73cbe0 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -1,6 +1,7 @@ package conntrack import ( + "context" "net" "testing" "time" @@ -34,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) { assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.cleanupTicker) - assert.NotNil(t, tracker.done) + assert.NotNil(t, tracker.tickerCancel) }) } } @@ -154,18 +155,21 @@ func TestUDPTracker_Cleanup(t *testing.T) { timeout := 50 * time.Millisecond cleanupInterval := 25 * time.Millisecond + ctx, tickerCancel := context.WithCancel(context.Background()) + defer tickerCancel() + // Create tracker with custom cleanup interval tracker := &UDPTracker{ connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(cleanupInterval), - done: make(chan struct{}), + tickerCancel: tickerCancel, ipPool: NewPreallocatedIPs(), logger: logger, } // Start cleanup routine - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) // Add some connections connections := []struct {