[client] Fix inbound tracking in userspace firewall (#3111)

* Don't create state for inbound SYN

* Allow final ack in some cases

* Relax state machine test a little
This commit is contained in:
Viktor Liu 2024-12-26 00:51:27 +01:00 committed by GitHub
parent 0dbaddc7be
commit b3c87cb5d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 38 deletions

View File

@ -10,7 +10,6 @@ import (
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP SourceIP net.IP
DestIP net.IP DestIP net.IP
SourcePort uint16 SourcePort uint16

View File

@ -62,6 +62,7 @@ type TCPConnKey struct {
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
State TCPState State TCPState
sync.RWMutex
} }
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
@ -131,36 +132,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
// Handle new SYN packets
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.Lock()
if _, exists := t.connections[key]; !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, dstIP)
copyIP(dstIPCopy, srcIP)
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: dstPort,
DestPort: srcPort,
},
State: TCPStateSynReceived,
}
conn.lastSeen.Store(time.Now().UnixNano())
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
return true
}
// Look up existing connection
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
@ -172,8 +145,7 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
// Handle RST packets // Handle RST packets
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
conn.Lock() conn.Lock()
isEstablished := conn.IsEstablished() if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetEstablished(false) conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
@ -183,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
// Update state
conn.Lock() conn.Lock()
t.updateState(conn, flags, false) t.updateState(conn, flags, false)
conn.UpdateLastSeen() conn.UpdateLastSeen()
@ -306,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return flags&TCPFin != 0 || flags&TCPAck != 0 return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck: case TCPStateLastAck:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
} }
return false return false
} }

View File

@ -125,11 +125,8 @@ func TestTCPStateMachine(t *testing.T) {
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Verify connection is closed // Connection is logically dead but we don't enforce blocking subsequent packets
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) // The connection will be cleaned up by timeout
t.Helper()
require.False(t, valid, "Data should be blocked after RST")
}, },
}, },
{ {