From 474fb3330584a66f2aa0707074623a2872c61fcf Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 4 Jan 2025 20:52:26 +0100 Subject: [PATCH] Remove established field from udp and icmp (unused) --- client/firewall/uspfilter/conntrack/common.go | 21 ++++----------- .../uspfilter/conntrack/common_test.go | 26 ------------------- client/firewall/uspfilter/conntrack/icmp.go | 9 +++---- client/firewall/uspfilter/conntrack/tcp.go | 19 +++++++++++--- client/firewall/uspfilter/conntrack/udp.go | 9 +++---- 5 files changed, 26 insertions(+), 58 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index e459bc75a..f5f502540 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -10,12 +10,11 @@ import ( // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - lastSeen atomic.Int64 // Unix nano for atomic access - established atomic.Bool + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access } // these small methods will be inlined by the compiler @@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() { b.lastSeen.Store(time.Now().UnixNano()) } -// IsEstablished safely checks if connection is established -func (b *BaseConnTrack) IsEstablished() bool { - return b.established.Load() -} - -// SetEstablished safely sets the established state -func (b *BaseConnTrack) SetEstablished(state bool) { - b.established.Store(state) -} - // GetLastSeen safely gets the last seen timestamp func (b *BaseConnTrack) GetLastSeen() time.Time { return time.Unix(0, b.lastSeen.Load()) diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 152e4e482..81fa64b19 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -40,32 +40,6 @@ func BenchmarkIPOperations(b *testing.B) { }) } -func BenchmarkAtomicOperations(b *testing.B) { - conn := &BaseConnTrack{} - b.Run("UpdateLastSeen", func(b *testing.B) { - for i := 0; i < b.N; i++ { - conn.UpdateLastSeen() - } - }) - - b.Run("IsEstablished", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = conn.IsEstablished() - } - }) - - b.Run("SetEstablished", func(b *testing.B) { - for i := 0; i < b.N; i++ { - conn.SetEstablished(i%2 == 0) - } - }) - - b.Run("GetLastSeen", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = conn.GetLastSeen() - } - }) -} // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index f9f5fad69..7796b0908 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -66,7 +66,6 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { // TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := makeICMPKey(srcIP, dstIP, id, seq) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -84,15 +83,14 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u ID: id, Sequence: seq, } - conn.lastSeen.Store(now) - conn.established.Store(true) + conn.UpdateLastSeen() t.connections[key] = conn t.logger.Trace("New ICMP connection %v", key) } t.mutex.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request @@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq return false } - return conn.IsEstablished() && - ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && conn.Sequence == seq diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 3f082d32c..7c12e8ad0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -5,6 +5,7 @@ package conntrack import ( "net" "sync" + "sync/atomic" "time" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" @@ -63,10 +64,21 @@ type TCPConnKey struct { // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { BaseConnTrack - State TCPState + State TCPState + established atomic.Bool sync.RWMutex } +// IsEstablished safely checks if connection is established +func (t *TCPConnTrack) IsEstablished() bool { + return t.established.Load() +} + +// SetEstablished safely sets the established state +func (t *TCPConnTrack) SetEstablished(state bool) { + t.established.Store(state) +} + // TCPTracker manages TCP connection states type TCPTracker struct { logger *nblog.Logger @@ -97,7 +109,6 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { // Create key before lock key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -117,7 +128,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d }, State: TCPStateNew, } - conn.lastSeen.Store(now) + conn.UpdateLastSeen() conn.established.Store(false) t.connections[key] = conn @@ -129,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.Lock() t.updateState(conn, flags, true) conn.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound TCP packet matches a tracked connection diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 5e37f5ceb..7e030f7f2 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -53,7 +53,6 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -71,15 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d DestPort: dstPort, }, } - conn.lastSeen.Store(now) - conn.established.Store(true) + conn.UpdateLastSeen() t.connections[key] = conn t.logger.Trace("New UDP connection: %v", conn) } t.mutex.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound packet matches a tracked connection @@ -98,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - return conn.IsEstablished() && - ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && conn.SourcePort == dstPort