Add flow userspace counters (#3438)

This commit is contained in:
Viktor Liu 2025-03-06 16:52:56 +01:00 committed by GitHub
parent b180edbe5c
commit 5ff77b3595
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 307 additions and 232 deletions

View File

@ -17,9 +17,11 @@ type BaseConnTrack struct {
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
}
// these small methods will be inlined by the compiler
@ -29,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())

View File

@ -32,11 +32,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
}
}
})
@ -57,11 +57,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
}
}
})

View File

@ -68,7 +68,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
return tracker
}
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16) (ICMPConnKey, bool) {
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@ -81,6 +81,7 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
@ -89,21 +90,22 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) {
if _, exists := t.updateIfExists(dstIP, srcIP, id); !exists {
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress)
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress)
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, id)
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, size int) {
// TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
@ -113,7 +115,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, size)
return
}
@ -138,7 +140,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8) bool {
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
@ -158,6 +160,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
@ -181,7 +184,8 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn)
}
}
@ -207,11 +211,15 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) {
t.flowLogger.StoreEvent(nftypes.EventFields{
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
Direction: direction,
@ -220,5 +228,13 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
})
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
}

View File

@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
}
})
@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
}
})
}

View File

@ -88,6 +88,8 @@ const (
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
@ -144,7 +146,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@ -161,6 +163,8 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
@ -168,21 +172,21 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
@ -193,9 +197,9 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.established.Store(false)
@ -212,7 +216,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool {
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
@ -239,6 +243,7 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn)
@ -427,7 +432,7 @@ func (t *TCPTracker) cleanup() {
// Return IPs to pool
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
@ -472,5 +477,9 @@ func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) {
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
},
},
{
@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
},
wantValid: true,
desc: "Should accept RST for established connection",
@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
},
wantValid: false,
desc: "Should reject RST without connection",
@ -228,12 +228,12 @@ func TestRSTHandling(t *testing.T) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}
func BenchmarkTCPTracker(b *testing.B) {
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
@ -259,12 +259,12 @@ func BenchmarkTCPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
}
})
@ -279,9 +279,9 @@ func BenchmarkTCPTracker(b *testing.B) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
}
i++
}
@ -299,7 +299,7 @@ func BenchmarkCleanup(b *testing.B) {
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire

View File

@ -21,6 +21,8 @@ const (
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
}
// UDPTracker manages UDP connection states
@ -54,19 +56,19 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, size)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) {
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@ -80,6 +82,7 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
@ -87,8 +90,8 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
@ -99,9 +102,9 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.UpdateLastSeen()
@ -114,7 +117,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool {
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
@ -131,6 +134,7 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
@ -155,7 +159,8 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout)", key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn)
}
}
@ -201,5 +206,9 @@ func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) {
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked
key := ConnKey{
@ -76,7 +76,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct {
name string
@ -148,7 +148,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got)
})
}
@ -194,7 +194,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
}
// Verify initial connections
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
}
})
@ -237,12 +237,12 @@ func BenchmarkUDPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
}
})
}

View File

@ -15,13 +15,16 @@ import (
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
flowID := uuid.New()
// Extract ICMP header to get type and code
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
@ -33,8 +36,6 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
@ -42,52 +43,36 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true
return
}
response := make([]byte, f.endpoint.mtu)
@ -96,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return true
return
}
ipHdr := make([]byte, header.IPv4MinimumSize)
@ -117,13 +102,11 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return true
return
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return true
}
// sendICMPEvent stores flow events for ICMP packets
@ -138,5 +121,7 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
}

View File

@ -22,7 +22,14 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
@ -51,6 +58,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
@ -66,7 +74,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
@ -98,17 +106,27 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 6,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@ -165,13 +165,19 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message
return
}
@ -184,7 +190,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
@ -212,13 +217,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
@ -238,7 +244,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}()
errChan := make(chan error, 2)
@ -265,8 +271,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
}
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
@ -276,7 +282,18 @@ func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.Tr
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() {

View File

@ -281,7 +281,7 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
@ -391,7 +391,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData)
dropped := m.processOutgoingHooks(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@ -510,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData)
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData)
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
}
// UpdateLocalIPs updates the list of local IPs
@ -524,7 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@ -544,7 +544,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
// Track all protocols if stateful mode is enabled
if m.stateful {
m.trackOutbound(d, srcIP, dstIP)
m.trackOutbound(d, srcIP, dstIP, size)
}
// Process UDP hooks even if stateful mode is disabled
@ -593,29 +593,29 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode)
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
}
}
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) {
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode)
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
}
}
@ -637,7 +637,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte)
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool {
func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@ -653,12 +653,12 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false
}
if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
@ -666,7 +666,7 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
@ -685,6 +685,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true
}
@ -695,7 +697,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
}
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP)
m.trackInbound(d, srcIP, dstIP, size)
return false
}
@ -802,7 +804,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool {
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
@ -811,6 +813,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
size,
)
case layers.LayerTypeUDP:
@ -819,6 +822,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
size,
)
case layers.LayerTypeICMPv4:
@ -827,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
dstIP,
d.icmp4.Id,
d.icmp4.TypeCode.Type(),
size,
)
// TODO: ICMPv6

View File

@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound)
manager.dropFilter(inbound, 0)
}
})
}
@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
}
// Test packet
@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut)
manager.processOutgoingHooks(testOut, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn)
manager.dropFilter(testIn, 0)
}
})
}
@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound)
manager.dropFilter(inbound, 0)
}
})
}
@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound)
manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
manager.processOutgoingHooks(ack, 0)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound)
manager.dropFilter(inbound, 0)
}
})
}
@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
manager.processOutgoingHooks(syn, 0)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack)
manager.dropFilter(synack, 0)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
manager.processOutgoingHooks(ack, 0)
}
// Prepare test packets simulating bidirectional traffic
@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx])
manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx])
manager.dropFilter(inPackets[connIdx], 0)
}
})
}
@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
// Data transfer
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
// Connection teardown
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
}
@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack)
manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
manager.processOutgoingHooks(ack, 0)
}
// Pre-generate test packets
@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx])
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
}
})
})
@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
}
})
})

View File

@ -192,7 +192,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet)
isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet)
isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}

View File

@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes()) {
if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted")
return
}
@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes())
result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result)
require.True(t, hookCalled)
@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes())
result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result)
}
@ -569,7 +569,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes())
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes())
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes())
drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description)
})
}

View File

@ -11,10 +11,10 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool
DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool
DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
@ -58,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@ -82,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) {
if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@ -50,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
}
// RemovePacketHook mocks base method.

View File

@ -458,7 +458,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)

View File

@ -2,6 +2,7 @@ package types
import (
"net/netip"
"strconv"
"time"
"github.com/google/uuid"
@ -27,8 +28,10 @@ func (p Protocol) String() string {
return "TCP"
case 17:
return "UDP"
case 132:
return "SCTP"
default:
return "unknown"
return strconv.FormatUint(uint64(p), 10)
}
}