Add flow userspace counters (#3438)

This commit is contained in:
Viktor Liu 2025-03-06 16:52:56 +01:00 committed by GitHub
parent b180edbe5c
commit 5ff77b3595
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 307 additions and 232 deletions

View File

@ -17,9 +17,11 @@ type BaseConnTrack struct {
Direction nftypes.Direction Direction nftypes.Direction
SourceIP netip.Addr SourceIP netip.Addr
DestIP netip.Addr DestIP netip.Addr
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@ -29,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())

View File

@ -32,11 +32,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
} }
} }
}) })
@ -57,11 +57,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
} }
} }
}) })

View File

@ -68,7 +68,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
return tracker return tracker
} }
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16) (ICMPConnKey, bool) { func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{ key := ICMPConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@ -81,6 +81,7 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@ -89,21 +90,22 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress) t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id) // TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
} }
@ -113,7 +115,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code) t.sendStartEvent(direction, srcIP, dstIP, typ, code, size)
return return
} }
@ -138,7 +140,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false return false
} }
@ -158,6 +160,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true return true
} }
@ -181,7 +184,8 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key) t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
} }
@ -207,11 +211,15 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) {
DestIP: conn.DestIP, DestIP: conn.DestIP,
ICMPType: conn.ICMPType, ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode, ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
}) })
} }
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) { func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, size int) {
t.flowLogger.StoreEvent(nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
Direction: direction, Direction: direction,
@ -220,5 +228,13 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
DestIP: dstIP, DestIP: dstIP,
ICMPType: typ, ICMPType: typ,
ICMPCode: code, ICMPCode: code,
}) }
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
} }

View File

@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
} }
}) })
} }

View File

@ -88,6 +88,8 @@ const (
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState State TCPState
established atomic.Bool established atomic.Bool
tombstone atomic.Bool tombstone atomic.Bool
@ -144,7 +146,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@ -161,6 +163,8 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock() conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@ -168,21 +172,21 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, size)
} }
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, size)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists { if exists {
return return
} }
@ -193,9 +197,9 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
Direction: direction, Direction: direction,
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
},
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
},
} }
conn.established.Store(false) conn.established.Store(false)
@ -212,7 +216,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool { func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{ key := ConnKey{
SrcIP: dstIP, SrcIP: dstIP,
DstIP: srcIP, DstIP: srcIP,
@ -239,6 +243,7 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetEstablished(false) conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key) t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn)
@ -427,7 +432,7 @@ func (t *TCPTracker) cleanup() {
// Return IPs to pool // Return IPs to pool
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key) t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change // event already handled by state change
if conn.State != TCPStateTimeWait { if conn.State != TCPStateTimeWait {
@ -472,5 +477,9 @@ func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) {
DestIP: conn.DestIP, DestIP: conn.DestIP,
SourcePort: conn.SourcePort, SourcePort: conn.SourcePort,
DestPort: conn.DestPort, DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
}) })
} }

View File

@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc) require.Equal(t, !tt.wantDrop, isValid, tt.desc)
}) })
} }
@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper() t.Helper()
// Send initial SYN // Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK // Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK // Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer // Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake") require.True(t, valid, "Data should be allowed after handshake")
}, },
}, },
@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN // Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN // Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed") require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side // Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed") require.True(t, valid, "FIN should be allowed")
// Send final ACK // Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
}, },
{ {
@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets // Connection is logically dead but we don't enforce blocking subsequent packets
@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK // Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed") require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK // Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed") require.True(t, valid, "Final ACKs should be allowed")
}, },
}, },
@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established", name: "RST in established",
setupState: func() { setupState: func() {
// Establish connection first // Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: true, wantValid: true,
desc: "Should accept RST for established connection", desc: "Should accept RST for established connection",
@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection", name: "RST without connection",
setupState: func() {}, setupState: func() {},
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: false, wantValid: false,
desc: "Should reject RST without connection", desc: "Should reject RST without connection",
@ -228,12 +228,12 @@ func TestRSTHandling(t *testing.T) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
} }
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} }
}) })
@ -259,12 +259,12 @@ func BenchmarkTCPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
} }
}) })
@ -279,9 +279,9 @@ func BenchmarkTCPTracker(b *testing.B) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
if i%2 == 0 { if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else { } else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
} }
i++ i++
} }
@ -299,7 +299,7 @@ func BenchmarkCleanup(b *testing.B) {
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
// Wait for connections to expire // Wait for connections to expire

View File

@ -21,6 +21,8 @@ const (
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
type UDPConnTrack struct { type UDPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
} }
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
@ -54,19 +56,19 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, size)
} }
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, size)
} }
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@ -80,6 +82,7 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@ -87,8 +90,8 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
} }
@ -99,9 +102,9 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
Direction: direction, Direction: direction,
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
},
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
},
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
@ -114,7 +117,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool { func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{ key := ConnKey{
SrcIP: dstIP, SrcIP: dstIP,
DstIP: srcIP, DstIP: srcIP,
@ -131,6 +134,7 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true return true
} }
@ -155,7 +159,8 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout)", key) t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn) t.sendEvent(nftypes.TypeEnd, conn)
} }
} }
@ -201,5 +206,9 @@ func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) {
DestIP: conn.DestIP, DestIP: conn.DestIP,
SourcePort: conn.SourcePort, SourcePort: conn.SourcePort,
DestPort: conn.DestPort, DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
}) })
} }

View File

@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked // Verify connection was tracked
key := ConnKey{ key := ConnKey{
@ -76,7 +76,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
dstPort := uint16(53) dstPort := uint16(53)
// Track outbound connection // Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct { tests := []struct {
name string name string
@ -148,7 +148,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 { if tt.sleep > 0 {
time.Sleep(tt.sleep) time.Sleep(tt.sleep)
} }
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
@ -194,7 +194,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
} }
for _, conn := range connections { for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
} }
// Verify initial connections // Verify initial connections
@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
} }
}) })
@ -237,12 +237,12 @@ func BenchmarkUDPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
} }
}) })
} }

View File

@ -15,13 +15,16 @@ import (
// handleICMP handles ICMP packets from the network stack // handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
flowID := uuid.New()
// Extract ICMP header to get type and code
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type()) icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code()) icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
@ -33,8 +36,6 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err != nil { if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
} }
@ -42,52 +43,36 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err) f.logger.Debug("Failed to close ICMP socket: %v", err)
} }
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}() }()
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader()) fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response // For Echo Requests, send and handle response
switch icmpHdr.Type() { if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
case header.ICMPv4Echo: f.handleEchoResponse(icmpHdr, conn, id)
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
} }
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true return true
} }
f.logger.Trace("Forwarded ICMP packet %v type %v code %v", func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
epID(id), icmpHdr.Type(), icmpHdr.Code())
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err) f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true return
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
@ -96,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err) f.logger.Error("Failed to read ICMP response: %v", err)
} }
return true return
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@ -117,13 +102,11 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err) f.logger.Error("Failed to inject ICMP response: %v", err)
return true return
} }
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
return true
} }
// sendICMPEvent stores flow events for ICMP packets // sendICMPEvent stores flow events for ICMP packets
@ -138,5 +121,7 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
ICMPType: icmpType, ICMPType: icmpType,
ICMPCode: icmpCode, ICMPCode: icmpCode,
// TODO: get packets/bytes
}) })
} }

View File

@ -22,7 +22,14 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID() id := r.ID()
flowID := uuid.New() flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
@ -51,6 +58,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id)) f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID) go f.proxyTCP(id, inConn, outConn, ep, flowID)
@ -66,7 +74,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
} }
ep.Close() ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}() }()
// Create context for managing the proxy goroutines // Create context for managing the proxy goroutines
@ -98,17 +106,27 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
} }
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: 6, Protocol: nftypes.TCP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort, SourcePort: id.LocalPort,
DestPort: id.RemotePort, DestPort: id.RemotePort,
}) }
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
} }

View File

@ -165,13 +165,19 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
} }
flowID := uuid.New() flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message // TODO: Send ICMP error message
return return
} }
@ -184,7 +190,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return return
} }
@ -212,13 +217,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id)) f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
} }
@ -238,7 +244,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id) f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}() }()
errChan := make(chan error, 2) errChan := make(chan error, 2)
@ -265,8 +271,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
} }
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version // sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
f.flowLogger.StoreEvent(nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
@ -276,7 +282,18 @@ func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.Tr
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort, SourcePort: id.LocalPort,
DestPort: id.RemotePort, DestPort: id.RemotePort,
}) }
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
} }
func (c *udpPacketConn) updateLastSeen() { func (c *udpPacketConn) updateLastSeen() {

View File

@ -281,7 +281,7 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder
} }
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool { func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP) allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found" msg := "No existing connection found"
if allowed { if allowed {
msg = m.buildConntrackStateMessage(d) msg = m.buildConntrackStateMessage(d)
@ -391,7 +391,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state // will create or update the connection state
dropped := m.processOutgoingHooks(packetData) dropped := m.processOutgoingHooks(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else { } else {

View File

@ -510,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
} }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData) return m.processOutgoingHooks(packetData, size)
} }
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData) return m.dropFilter(packetData, size)
} }
// UpdateLocalIPs updates the list of local IPs // UpdateLocalIPs updates the list of local IPs
@ -524,7 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface) return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@ -544,7 +544,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
// Track all protocols if stateful mode is enabled // Track all protocols if stateful mode is enabled
if m.stateful { if m.stateful {
m.trackOutbound(d, srcIP, dstIP) m.trackOutbound(d, srcIP, dstIP, size)
} }
// Process UDP hooks even if stateful mode is disabled // Process UDP hooks even if stateful mode is disabled
@ -593,29 +593,29 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
} }
} }
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) { func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
} }
} }
@ -637,7 +637,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte)
// dropFilter implements filtering logic for incoming packets. // dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool { func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@ -653,12 +653,12 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked. // This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false return false
} }
if m.localipmanager.IsLocalIP(dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData) return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
} }
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
@ -666,7 +666,7 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
_, pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
@ -685,6 +685,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
}) })
return true return true
} }
@ -695,7 +697,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP) m.trackInbound(d, srcIP, dstIP, size)
return false return false
} }
@ -802,7 +804,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true return true
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound( return m.tcpTracker.IsValidInbound(
@ -811,6 +813,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
uint16(d.tcp.SrcPort), uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort), uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp), getTCPFlags(&d.tcp),
size,
) )
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@ -819,6 +822,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
dstIP, dstIP,
uint16(d.udp.SrcPort), uint16(d.udp.SrcPort),
uint16(d.udp.DstPort), uint16(d.udp.DstPort),
size,
) )
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@ -827,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr)
dstIP, dstIP,
d.icmp4.Id, d.icmp4.Id,
d.icmp4.TypeCode.Type(), d.icmp4.TypeCode.Type(),
size,
) )
// TODO: ICMPv6 // TODO: ICMPv6

View File

@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Test packet // Test packet
@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut) manager.processOutgoingHooks(testOut, 0)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn) manager.dropFilter(testIn, 0)
} }
}) })
} }
@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
} }
@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
} }
@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Pre-generate test packets // Pre-generate test packets
@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
}) })
@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
}) })

View File

@ -192,7 +192,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist") require.True(t, isDropped, "Packet should be dropped when no rules exist")
}) })
@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
}) })
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped) require.Equal(t, tc.shouldBeBlocked, isDropped)
}) })
} }

View File

@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes()) { if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes()) result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes()) result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result) require.False(t, result)
} }
@ -569,7 +569,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes()) drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes()) drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes()) drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes()) drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@ -11,10 +11,10 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host // DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
@ -58,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@ -82,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) { if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@ -50,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
} }
// DropIncoming mocks base method. // DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0) ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
} }
// DropOutgoing mocks base method. // DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0) ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
} }
// RemovePacketHook mocks base method. // RemovePacketHook mocks base method.

View File

@ -458,7 +458,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet) packetfilter.EXPECT().SetNetwork(ipNet)

View File

@ -2,6 +2,7 @@ package types
import ( import (
"net/netip" "net/netip"
"strconv"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -27,8 +28,10 @@ func (p Protocol) String() string {
return "TCP" return "TCP"
case 17: case 17:
return "UDP" return "UDP"
case 132:
return "SCTP"
default: default:
return "unknown" return strconv.FormatUint(uint64(p), 10)
} }
} }