diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index a4b1971bf..e459bc75a 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -10,7 +10,6 @@ import ( // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - sync.RWMutex SourceIP net.IP DestIP net.IP SourcePort uint16 diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index e8d20f41c..a7968dc73 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -62,6 +62,7 @@ type TCPConnKey struct { type TCPConnTrack struct { BaseConnTrack State TCPState + sync.RWMutex } // TCPTracker manages TCP connection states @@ -131,36 +132,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - // Handle new SYN packets - if flags&TCPSyn != 0 && flags&TCPAck == 0 { - key := makeConnKey(dstIP, srcIP, dstPort, srcPort) - t.mutex.Lock() - if _, exists := t.connections[key]; !exists { - // Use preallocated IPs - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, dstIP) - copyIP(dstIPCopy, srcIP) - - conn := &TCPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - SourcePort: dstPort, - DestPort: srcPort, - }, - State: TCPStateSynReceived, - } - conn.lastSeen.Store(time.Now().UnixNano()) - conn.established.Store(false) - t.connections[key] = conn - } - t.mutex.Unlock() - return true - } - - // Look up existing connection key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() conn, exists := t.connections[key] t.mutex.RUnlock() @@ -172,8 +145,7 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, // Handle RST packets if flags&TCPRst != 0 { conn.Lock() - isEstablished := conn.IsEstablished() - if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { conn.State = TCPStateClosed conn.SetEstablished(false) conn.Unlock() @@ -183,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - // Update state conn.Lock() t.updateState(conn, flags, false) conn.UpdateLastSeen() @@ -306,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { return flags&TCPFin != 0 || flags&TCPAck != 0 case TCPStateLastAck: return flags&TCPAck != 0 + case TCPStateClosed: + // Accept retransmitted ACKs in closed state + // This is important because the final ACK might be lost + // and the peer will retransmit their FIN-ACK + return flags&TCPAck != 0 } return false } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 3933c8889..6c8f82423 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -125,11 +125,8 @@ func TestTCPStateMachine(t *testing.T) { valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) require.True(t, valid, "RST should be allowed for established connection") - // Verify connection is closed - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) - t.Helper() - - require.False(t, valid, "Data should be blocked after RST") + // Connection is logically dead but we don't enforce blocking subsequent packets + // The connection will be cleaned up by timeout }, }, {