[client] Prevent panic in case of double close call (#3475)

Prevent panic in case of double close call
This commit is contained in:
Zoltan Papp 2025-03-10 13:16:28 +01:00 committed by GitHub
parent 81040ff80a
commit 6bef474e9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 24 deletions

View File

@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net" "net"
"sync" "sync"
"time" "time"
@ -39,8 +40,8 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs ipPool *PreallocatedIPs
} }
@ -50,16 +51,18 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger, logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
@ -119,12 +122,14 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
conn.Sequence == seq conn.Sequence == seq
} }
func (t *ICMPTracker) cleanupRoutine() { func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@ -146,8 +151,7 @@ func (t *ICMPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections { for _, conn := range t.connections {

View File

@ -3,6 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"context"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -85,23 +86,26 @@ type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
done chan struct{} tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
ipPool *PreallocatedIPs ipPool *PreallocatedIPs
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
timeout: timeout, timeout: timeout,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
@ -315,12 +319,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false return false
} }
func (t *TCPTracker) cleanupRoutine() { func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@ -355,8 +361,7 @@ func (t *TCPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()

View File

@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net" "net"
"sync" "sync"
"time" "time"
@ -26,8 +27,8 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs ipPool *PreallocatedIPs
} }
@ -37,16 +38,18 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
@ -103,12 +106,14 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() { func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@ -131,8 +136,7 @@ func (t *UDPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() { func (t *UDPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections { for _, conn := range t.connections {

View File

@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net" "net"
"testing" "testing"
"time" "time"
@ -34,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker) assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done) assert.NotNil(t, tracker.tickerCancel)
}) })
} }
} }
@ -154,18 +155,21 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval // Create tracker with custom cleanup interval
tracker := &UDPTracker{ tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), tickerCancel: tickerCancel,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
logger: logger, logger: logger,
} }
// Start cleanup routine // Start cleanup routine
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
// Add some connections // Add some connections
connections := []struct { connections := []struct {