diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index b87ec6fd2..3de0bb3f4 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -13,13 +13,15 @@ import ( // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - FlowId uuid.UUID - Direction nftypes.Direction - SourceIP netip.Addr - DestIP netip.Addr - SourcePort uint16 - DestPort uint16 - lastSeen atomic.Int64 + FlowId uuid.UUID + Direction nftypes.Direction + SourceIP netip.Addr + DestIP netip.Addr + lastSeen atomic.Int64 + PacketsTx atomic.Uint64 + PacketsRx atomic.Uint64 + BytesTx atomic.Uint64 + BytesRx atomic.Uint64 } // these small methods will be inlined by the compiler @@ -29,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() { b.lastSeen.Store(time.Now().UnixNano()) } +// UpdateCounters safely updates the packet and byte counters +func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) { + if direction == nftypes.Egress { + b.PacketsTx.Add(1) + b.BytesTx.Add(uint64(bytes)) + } else { + b.PacketsRx.Add(1) + b.BytesRx.Add(uint64(bytes)) + } +} + // GetLastSeen safely gets the last seen timestamp func (b *BaseConnTrack) GetLastSeen() time.Time { return time.Unix(0, b.lastSeen.Load()) diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index e83b6aa85..ca1136e6f 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -32,11 +32,11 @@ func BenchmarkMemoryPressure(b *testing.B) { for i := 0; i < b.N; i++ { srcIdx := i % len(srcIPs) dstIdx := (i + 1) % len(dstIPs) - tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0) // Simulate some valid inbound packets if i%3 == 0 { - tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0) } } }) @@ -57,11 +57,11 @@ func BenchmarkMemoryPressure(b *testing.B) { for i := 0; i < b.N; i++ { srcIdx := i % len(srcIPs) dstIdx := (i + 1) % len(dstIPs) - tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0) // Simulate some valid inbound packets if i%3 == 0 { - tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0) } } }) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 7837c1a23..730edd6ce 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -68,7 +68,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty return tracker } -func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16) (ICMPConnKey, bool) { +func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) { key := ICMPConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -81,6 +81,7 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint if exists { conn.UpdateLastSeen() + conn.UpdateCounters(direction, size) return key, true } @@ -89,21 +90,22 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint } // TrackOutbound records an outbound ICMP connection -func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) { - if _, exists := t.updateIfExists(dstIP, srcIP, id); !exists { +func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, size int) { + if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, id, typecode, nftypes.Egress) + t.track(srcIP, dstIP, id, typecode, nftypes.Egress, size) } } // TrackInbound records an inbound ICMP Echo Request -func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode) { - t.track(srcIP, dstIP, id, typecode, nftypes.Ingress) +func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { + t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, size) } // track is the common implementation for tracking both inbound and outbound ICMP connections -func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { - key, exists := t.updateIfExists(srcIP, dstIP, id) +func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, size int) { + // TODO: icmp doesn't need to extend the timeout + key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) if exists { return } @@ -113,7 +115,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) - t.sendStartEvent(direction, srcIP, dstIP, typ, code) + t.sendStartEvent(direction, srcIP, dstIP, typ, code, size) return } @@ -138,7 +140,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request -func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8) bool { +func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { if icmpType != uint8(layers.ICMPv4TypeEchoReply) { return false } @@ -158,6 +160,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint } conn.UpdateLastSeen() + conn.UpdateCounters(nftypes.Ingress, size) return true } @@ -181,7 +184,8 @@ func (t *ICMPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Debug("Removed ICMP connection %s (timeout)", &key) + t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn) } } @@ -207,11 +211,15 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) { DestIP: conn.DestIP, ICMPType: conn.ICMPType, ICMPCode: conn.ICMPCode, + RxPackets: conn.PacketsRx.Load(), + TxPackets: conn.PacketsTx.Load(), + RxBytes: conn.BytesRx.Load(), + TxBytes: conn.BytesTx.Load(), }) } -func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) { - t.flowLogger.StoreEvent(nftypes.EventFields{ +func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, size int) { + fields := nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeStart, Direction: direction, @@ -220,5 +228,13 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad DestIP: dstIP, ICMPType: typ, ICMPCode: code, - }) + } + if direction == nftypes.Ingress { + fields.RxPackets = 1 + fields.RxBytes = uint64(size) + } else { + fields.TxPackets = 1 + fields.TxBytes = uint64(size) + } + t.flowLogger.StoreEvent(fields) } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 01040feb2..5a7b36a36 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0) } }) @@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0) + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0) } }) } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 892ae6bcb..96bbf0220 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -88,6 +88,8 @@ const ( // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { BaseConnTrack + SourcePort uint16 + DestPort uint16 State TCPState established atomic.Bool tombstone atomic.Bool @@ -144,7 +146,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -161,6 +163,8 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) conn.Unlock() + conn.UpdateCounters(direction, size) + return key, true } @@ -168,34 +172,34 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort } // TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists { +func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { + if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress) + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, size) } } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress) +func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, size) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags) +func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) { + key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists { return } conn := &TCPConnTrack{ BaseConnTrack: BaseConnTrack{ - FlowId: uuid.New(), - Direction: direction, - SourceIP: srcIP, - DestIP: dstIP, - SourcePort: srcPort, - DestPort: dstPort, + FlowId: uuid.New(), + Direction: direction, + SourceIP: srcIP, + DestIP: dstIP, }, + SourcePort: srcPort, + DestPort: dstPort, } conn.established.Store(false) @@ -212,7 +216,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d } // IsValidInbound checks if an inbound TCP packet matches a tracked connection -func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool { +func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool { key := ConnKey{ SrcIP: dstIP, DstIP: srcIP, @@ -239,6 +243,7 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort conn.State = TCPStateClosed conn.SetEstablished(false) conn.Unlock() + conn.UpdateCounters(nftypes.Ingress, size) t.logger.Trace("TCP connection reset: %s", key) t.sendEvent(nftypes.TypeEnd, conn) @@ -427,7 +432,7 @@ func (t *TCPTracker) cleanup() { // Return IPs to pool delete(t.connections, key) - t.logger.Trace("Cleaned up timed-out TCP connection %s", &key) + t.logger.Trace("Cleaned up timed-out TCP connection %s", key) // event already handled by state change if conn.State != TCPStateTimeWait { @@ -472,5 +477,9 @@ func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) { DestIP: conn.DestIP, SourcePort: conn.SourcePort, DestPort: conn.DestPort, + RxPackets: conn.PacketsRx.Load(), + TxPackets: conn.PacketsTx.Load(), + RxBytes: conn.BytesRx.Load(), + TxBytes: conn.BytesTx.Load(), }) } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 122deae1e..96558583d 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0) require.Equal(t, !tt.wantDrop, isValid, tt.desc) }) } @@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) { t.Helper() // Send initial SYN - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) // Receive SYN-ACK - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) require.True(t, valid, "SYN-ACK should be allowed") // Send ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) // Test data transfer - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0) require.True(t, valid, "Data should be allowed after handshake") }, }, @@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) { establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) // Send FIN - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // Receive ACK for FIN - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) require.True(t, valid, "ACK for FIN should be allowed") // Receive FIN from other side - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) require.True(t, valid, "FIN should be allowed") // Send final ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) }, }, { @@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) { establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) // Receive RST - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) require.True(t, valid, "RST should be allowed for established connection") // Connection is logically dead but we don't enforce blocking subsequent packets @@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) { establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) // Both sides send FIN+ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) require.True(t, valid, "Simultaneous FIN should be allowed") // Both sides send final ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) require.True(t, valid, "Final ACKs should be allowed") }, }, @@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) { name: "RST in established", setupState: func() { // Establish connection first - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) - tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) }, sendRST: func() { - tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) }, wantValid: true, desc: "Should accept RST for established connection", @@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) { name: "RST without connection", setupState: func() {}, sendRST: func() { - tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) }, wantValid: false, desc: "Should reject RST without connection", @@ -228,12 +228,12 @@ func TestRSTHandling(t *testing.T) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { t.Helper() - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) require.True(t, valid, "SYN-ACK should be allowed") - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) } func BenchmarkTCPTracker(b *testing.B) { @@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) } }) @@ -259,12 +259,12 @@ func BenchmarkTCPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0) } }) @@ -279,9 +279,9 @@ func BenchmarkTCPTracker(b *testing.B) { i := 0 for pb.Next() { if i%2 == 0 { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) } else { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0) } i++ } @@ -299,7 +299,7 @@ func BenchmarkCleanup(b *testing.B) { srcIP := netip.MustParseAddr("192.168.1.1") dstIP := netip.MustParseAddr("192.168.1.2") for i := 0; i < 10000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) } // Wait for connections to expire diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 465e400be..351b3c93a 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -21,6 +21,8 @@ const ( // UDPConnTrack represents a UDP connection state type UDPConnTrack struct { BaseConnTrack + SourcePort uint16 + DestPort uint16 } // UDPTracker manages UDP connection states @@ -54,19 +56,19 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp } // TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists { +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { + if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress) + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, size) } } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, size) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -80,6 +82,7 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() + conn.UpdateCounters(direction, size) return key, true } @@ -87,21 +90,21 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) { + key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } conn := &UDPConnTrack{ BaseConnTrack: BaseConnTrack{ - FlowId: uuid.New(), - Direction: direction, - SourceIP: srcIP, - DestIP: dstIP, - SourcePort: srcPort, - DestPort: dstPort, + FlowId: uuid.New(), + Direction: direction, + SourceIP: srcIP, + DestIP: dstIP, }, + SourcePort: srcPort, + DestPort: dstPort, } conn.UpdateLastSeen() @@ -114,7 +117,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d } // IsValidInbound checks if an inbound packet matches a tracked connection -func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool { +func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool { key := ConnKey{ SrcIP: dstIP, DstIP: srcIP, @@ -131,6 +134,7 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort } conn.UpdateLastSeen() + conn.UpdateCounters(nftypes.Ingress, size) return true } @@ -155,7 +159,8 @@ func (t *UDPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace("Removed UDP connection %s (timeout)", key) + t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn) } } @@ -201,5 +206,9 @@ func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) { DestIP: conn.DestIP, SourcePort: conn.SourcePort, DestPort: conn.DestPort, + RxPackets: conn.PacketsRx.Load(), + TxPackets: conn.PacketsTx.Load(), + RxBytes: conn.BytesRx.Load(), + TxBytes: conn.BytesTx.Load(), }) } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index db7fa0f51..14b912908 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { srcPort := uint16(12345) dstPort := uint16(53) - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0) // Verify connection was tracked key := ConnKey{ @@ -76,7 +76,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { dstPort := uint16(53) // Track outbound connection - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0) tests := []struct { name string @@ -148,7 +148,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { if tt.sleep > 0 { time.Sleep(tt.sleep) } - got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0) assert.Equal(t, tt.want, got) }) } @@ -194,7 +194,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { } for _, conn := range connections { - tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0) } // Verify initial connections @@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0) } }) @@ -237,12 +237,12 @@ func BenchmarkUDPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0) } }) } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 3d7e8857b..833854a21 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -15,13 +15,16 @@ import ( // handleICMP handles ICMP packets from the network stack func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { - flowID := uuid.New() - - // Extract ICMP header to get type and code icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) icmpType := uint8(icmpHdr.Type()) icmpCode := uint8(icmpHdr.Code()) + if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply { + // dont process our own replies + return true + } + + flowID := uuid.New() f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) @@ -33,8 +36,6 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf if err != nil { f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) - // This will make netstack reply on behalf of the original destination, that's ok for now return false } @@ -42,30 +43,15 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf if err := conn.Close(); err != nil { f.logger.Debug("Failed to close ICMP socket: %v", err) } - - f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) }() dstIP := f.determineDialAddr(id.LocalAddress) dst := &net.IPAddr{IP: dstIP} - // Get the complete ICMP message (header + data) fullPacket := stack.PayloadSince(pkt.TransportHeader()) payload := fullPacket.AsSlice() - // For Echo Requests, send and handle response - switch icmpHdr.Type() { - case header.ICMPv4Echo: - return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID) - case header.ICMPv4EchoReply: - // dont process our own replies - return true - default: - } - - // For other ICMP types (Time Exceeded, Destination Unreachable, etc) - _, err = conn.WriteTo(payload, dst) - if err != nil { + if _, err = conn.WriteTo(payload, dst); err != nil { f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) return true } @@ -73,21 +59,20 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf f.logger.Trace("Forwarded ICMP packet %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) + // For Echo Requests, send and handle response + if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { + f.handleEchoResponse(icmpHdr, conn, id) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) + } + + // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing return true } -func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool { - if _, err := conn.WriteTo(payload, dst); err != nil { - f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) - return true - } - - f.logger.Trace("Forwarded ICMP packet %v type %v code %v", - epID(id), icmpHdr.Type(), icmpHdr.Code()) - +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { f.logger.Error("Failed to set read deadline for ICMP response: %v", err) - return true + return } response := make([]byte, f.endpoint.mtu) @@ -96,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds if !isTimeout(err) { f.logger.Error("Failed to read ICMP response: %v", err) } - return true + return } ipHdr := make([]byte, header.IPv4MinimumSize) @@ -117,13 +102,11 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds if err := f.InjectIncomingPacket(fullPacket); err != nil { f.logger.Error("Failed to inject ICMP response: %v", err) - return true + return } f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", epID(id), icmpHdr.Type(), icmpHdr.Code()) - - return true } // sendICMPEvent stores flow events for ICMP packets @@ -138,5 +121,7 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), ICMPType: icmpType, ICMPCode: icmpCode, + + // TODO: get packets/bytes }) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 44aafa989..c3e1eca80 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -22,7 +22,14 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { id := r.ID() flowID := uuid.New() - f.sendTCPEvent(nftypes.TypeStart, flowID, id) + + f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil) + var success bool + defer func() { + if !success { + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil) + } + }() dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) @@ -51,6 +58,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) + success = true f.logger.Trace("forwarder: established TCP connection %v", epID(id)) go f.proxyTCP(id, inConn, outConn, ep, flowID) @@ -66,7 +74,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn } ep.Close() - f.sendTCPEvent(nftypes.TypeEnd, flowID, id) + f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep) }() // Create context for managing the proxy goroutines @@ -98,17 +106,27 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn } } -func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { - - f.flowLogger.StoreEvent(nftypes.EventFields{ +func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, - Protocol: 6, + Protocol: nftypes.TCP, // TODO: handle ipv6 SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), SourcePort: id.LocalPort, DestPort: id.RemotePort, - }) + } + + if ep != nil { + if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { + // fields are flipped since this is the in conn + // TODO: get bytes + fields.RxPackets = tcpStats.SegmentsSent.Value() + fields.TxPackets = tcpStats.SegmentsReceived.Value() + } + } + + f.flowLogger.StoreEvent(fields) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index db8aa1a2f..20e1ee3a7 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -165,13 +165,19 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { } flowID := uuid.New() - f.sendUDPEvent(nftypes.TypeStart, flowID, id) + + f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil) + var success bool + defer func() { + if !success { + f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil) + } + }() dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err) - f.sendUDPEvent(nftypes.TypeEnd, flowID, id) // TODO: Send ICMP error message return } @@ -184,7 +190,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - f.sendUDPEvent(nftypes.TypeEnd, flowID, id) return } @@ -212,13 +217,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) } - f.sendUDPEvent(nftypes.TypeEnd, flowID, id) return } f.udpForwarder.conns[id] = pConn f.udpForwarder.Unlock() + success = true f.logger.Trace("forwarder: established UDP connection %v", epID(id)) + go f.proxyUDP(connCtx, pConn, id, ep) } @@ -238,7 +244,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack delete(f.udpForwarder.conns, id) f.udpForwarder.Unlock() - f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id) + f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep) }() errChan := make(chan error, 2) @@ -265,8 +271,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack } // sendUDPEvent stores flow events for UDP connections, mirrors the TCP version -func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { - f.flowLogger.StoreEvent(nftypes.EventFields{ +func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { + fields := nftypes.EventFields{ FlowID: flowID, Type: typ, Direction: nftypes.Ingress, @@ -276,7 +282,18 @@ func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.Tr DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), SourcePort: id.LocalPort, DestPort: id.RemotePort, - }) + } + + if ep != nil { + if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { + // fields are flipped since this is the in conn + // TODO: get bytes + fields.RxPackets = tcpStats.PacketsSent.Value() + fields.TxPackets = tcpStats.PacketsReceived.Value() + } + } + + f.flowLogger.StoreEvent(fields) } func (c *udpPacketConn) updateLastSeen() { diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index a980cda29..cc5edc554 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -281,7 +281,7 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder } func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool { - allowed := m.isValidTrackedConnection(d, srcIP, dstIP) + allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0) msg := "No existing connection found" if allowed { msg = m.buildConntrackStateMessage(d) @@ -391,7 +391,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { // will create or update the connection state - dropped := m.processOutgoingHooks(packetData) + dropped := m.processOutgoingHooks(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) } else { diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index e663966ab..73b963d06 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -510,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { } // DropOutgoing filter outgoing packets -func (m *Manager) DropOutgoing(packetData []byte) bool { - return m.processOutgoingHooks(packetData) +func (m *Manager) DropOutgoing(packetData []byte, size int) bool { + return m.processOutgoingHooks(packetData, size) } // DropIncoming filter incoming packets -func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData) +func (m *Manager) DropIncoming(packetData []byte, size int) bool { + return m.dropFilter(packetData, size) } // UpdateLocalIPs updates the list of local IPs @@ -524,7 +524,7 @@ func (m *Manager) UpdateLocalIPs() error { return m.localipmanager.UpdateLocalIPs(m.wgIface) } -func (m *Manager) processOutgoingHooks(packetData []byte) bool { +func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) @@ -544,7 +544,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { // Track all protocols if stateful mode is enabled if m.stateful { - m.trackOutbound(d, srcIP, dstIP) + m.trackOutbound(d, srcIP, dstIP, size) } // Process UDP hooks even if stateful mode is disabled @@ -593,29 +593,29 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) + m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) + m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode) + m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) } } -func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) { +func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) case layers.LayerTypeICMPv4: - m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode) + m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) } } @@ -637,7 +637,7 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte) // dropFilter implements filtering logic for incoming packets. // If it returns true, the packet should be dropped. -func (m *Manager) dropFilter(packetData []byte) bool { +func (m *Manager) dropFilter(packetData []byte, size int) bool { d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) @@ -653,12 +653,12 @@ func (m *Manager) dropFilter(packetData []byte) bool { // For all inbound traffic, first check if it matches a tracked connection. // This must happen before any other filtering because the packets are statefully tracked. - if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { return false } if m.localipmanager.IsLocalIP(dstIP) { - return m.handleLocalTraffic(d, srcIP, dstIP, packetData) + return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size) } return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) @@ -666,7 +666,7 @@ func (m *Manager) dropFilter(packetData []byte) bool { // handleLocalTraffic handles local traffic. // If it returns true, the packet should be dropped. -func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool { +func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { _, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) @@ -685,6 +685,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet SourcePort: srcPort, DestPort: dstPort, // TODO: icmp type/code + RxPackets: 1, + RxBytes: uint64(size), }) return true } @@ -695,7 +697,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet } // track inbound packets to get the correct direction and session id for flows - m.trackInbound(d, srcIP, dstIP) + m.trackInbound(d, srcIP, dstIP, size) return false } @@ -802,7 +804,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { return true } -func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool { +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool { switch d.decoded[1] { case layers.LayerTypeTCP: return m.tcpTracker.IsValidInbound( @@ -811,6 +813,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), getTCPFlags(&d.tcp), + size, ) case layers.LayerTypeUDP: @@ -819,6 +822,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), + size, ) case layers.LayerTypeICMPv4: @@ -827,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) dstIP, d.icmp4.Id, d.icmp4.TypeCode.Type(), + size, ) // TODO: ICMPv6 diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 9b06b2803..b43ac2b16 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) { // For stateful scenarios, establish the connection if sc.stateful { - manager.processOutgoingHooks(outbound) + manager.processOutgoingHooks(outbound, 0) } // Measure inbound packet processing b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound) + manager.dropFilter(inbound, 0) } }) } @@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) { for i := 0; i < count; i++ { outbound := generatePacket(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, layers.IPProtocolTCP) - manager.processOutgoingHooks(outbound) + manager.processOutgoingHooks(outbound, 0) } // Test packet @@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) { testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) // First establish our test connection - manager.processOutgoingHooks(testOut) + manager.processOutgoingHooks(testOut, 0) b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(testIn) + manager.dropFilter(testIn, 0) } }) } @@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) if sc.established { - manager.processOutgoingHooks(outbound) + manager.processOutgoingHooks(outbound, 0) } b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound) + manager.dropFilter(inbound, 0) } }) } @@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { // For stateful cases and established connections if !strings.Contains(sc.name, "allow_non_wg") || (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { - manager.processOutgoingHooks(outbound) + manager.processOutgoingHooks(outbound, 0) // For TCP post-handshake, simulate full handshake if sc.state == "post_handshake" { // SYN syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) - manager.processOutgoingHooks(syn) + manager.processOutgoingHooks(syn, 0) // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack) + manager.dropFilter(synack, 0) // ACK ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) - manager.processOutgoingHooks(ack) + manager.processOutgoingHooks(ack, 0) } } b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound) + manager.dropFilter(inbound, 0) } }) } @@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) { // Initial SYN syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPSyn)) - manager.processOutgoingHooks(syn) + manager.processOutgoingHooks(syn, 0) // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack) + manager.dropFilter(synack, 0) // ACK ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPAck)) - manager.processOutgoingHooks(ack) + manager.processOutgoingHooks(ack, 0) } // Prepare test packets simulating bidirectional traffic @@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) { // Simulate bidirectional traffic // First outbound data - manager.processOutgoingHooks(outPackets[connIdx]) + manager.processOutgoingHooks(outPackets[connIdx], 0) // Then inbound response - this is what we're actually measuring - manager.dropFilter(inPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], 0) } }) } @@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) { p := patterns[connIdx] // Connection establishment - manager.processOutgoingHooks(p.syn) - manager.dropFilter(p.synAck) - manager.processOutgoingHooks(p.ack) + manager.processOutgoingHooks(p.syn, 0) + manager.dropFilter(p.synAck, 0) + manager.processOutgoingHooks(p.ack, 0) // Data transfer - manager.processOutgoingHooks(p.request) - manager.dropFilter(p.response) + manager.processOutgoingHooks(p.request, 0) + manager.dropFilter(p.response, 0) // Connection teardown - manager.processOutgoingHooks(p.finClient) - manager.dropFilter(p.ackServer) - manager.dropFilter(p.finServer) - manager.processOutgoingHooks(p.ackClient) + manager.processOutgoingHooks(p.finClient, 0) + manager.dropFilter(p.ackServer, 0) + manager.dropFilter(p.finServer, 0) + manager.processOutgoingHooks(p.ackClient, 0) } }) } @@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { for i := 0; i < sc.connCount; i++ { syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPSyn)) - manager.processOutgoingHooks(syn) + manager.processOutgoingHooks(syn, 0) synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack) + manager.dropFilter(synack, 0) ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPAck)) - manager.processOutgoingHooks(ack) + manager.processOutgoingHooks(ack, 0) } // Pre-generate test packets @@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { counter++ // Simulate bidirectional traffic - manager.processOutgoingHooks(outPackets[connIdx]) - manager.dropFilter(inPackets[connIdx]) + manager.processOutgoingHooks(outPackets[connIdx], 0) + manager.dropFilter(inPackets[connIdx], 0) } }) }) @@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { p := patterns[connIdx] // Full connection lifecycle - manager.processOutgoingHooks(p.syn) - manager.dropFilter(p.synAck) - manager.processOutgoingHooks(p.ack) + manager.processOutgoingHooks(p.syn, 0) + manager.dropFilter(p.synAck, 0) + manager.processOutgoingHooks(p.ack, 0) - manager.processOutgoingHooks(p.request) - manager.dropFilter(p.response) + manager.processOutgoingHooks(p.request, 0) + manager.dropFilter(p.response, 0) - manager.processOutgoingHooks(p.finClient) - manager.dropFilter(p.ackServer) - manager.dropFilter(p.finServer) - manager.processOutgoingHooks(p.ackClient) + manager.processOutgoingHooks(p.finClient, 0) + manager.dropFilter(p.ackServer, 0) + manager.dropFilter(p.finServer, 0) + manager.processOutgoingHooks(p.ackClient, 0) } }) }) diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 7005d501c..3a97506f1 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -192,7 +192,7 @@ func TestPeerACLFiltering(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) { packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) - isDropped := manager.DropIncoming(packet) + isDropped := manager.DropIncoming(packet, 0) require.True(t, isDropped, "Packet should be dropped when no rules exist") }) @@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) { }) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) - isDropped := manager.DropIncoming(packet) + isDropped := manager.DropIncoming(packet, 0) require.Equal(t, tc.shouldBeBlocked, isDropped) }) } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 429265794..1db572618 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes()) { + if m.dropFilter(buf.Bytes(), 0) { t.Errorf("expected packet to be accepted") return } @@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) { require.NoError(t, err) // Test hook gets called - result := manager.processOutgoingHooks(buf.Bytes()) + result := manager.processOutgoingHooks(buf.Bytes(), 0) require.True(t, result) require.True(t, hookCalled) @@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) { err = gopacket.SerializeLayers(buf, opts, ipv4) require.NoError(t, err) - result = manager.processOutgoingHooks(buf.Bytes()) + result = manager.processOutgoingHooks(buf.Bytes(), 0) require.False(t, result) } @@ -569,7 +569,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.NoError(t, err) // Process outbound packet and verify connection tracking - drop := manager.DropOutgoing(outboundBuf.Bytes()) + drop := manager.DropOutgoing(outboundBuf.Bytes(), 0) require.False(t, drop, "Initial outbound packet should not be dropped") // Verify connection was tracked @@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { for _, cp := range checkPoints { time.Sleep(cp.sleep) - drop = manager.dropFilter(inboundBuf.Bytes()) + drop = manager.dropFilter(inboundBuf.Bytes(), 0) require.Equal(t, cp.shouldAllow, !drop, cp.description) // If the connection should still be valid, verify it exists @@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } // Create a new outbound connection for invalid tests - drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0) require.False(t, drop, "Second outbound packet should not be dropped") for _, tc := range invalidCases { @@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.NoError(t, err) // Verify the invalid packet is dropped - drop = manager.dropFilter(testBuf.Bytes()) + drop = manager.dropFilter(testBuf.Bytes(), 0) require.True(t, drop, tc.description) }) } diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index f21804683..c9b7e2448 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -11,10 +11,10 @@ import ( // PacketFilter interface for firewall abilities type PacketFilter interface { // DropOutgoing filter outgoing packets from host to external destinations - DropOutgoing(packetData []byte) bool + DropOutgoing(packetData []byte, size int) bool // DropIncoming filter incoming packets from external sources to host - DropIncoming(packetData []byte) bool + DropIncoming(packetData []byte, size int) bool // AddUDPPacketHook calls hook when UDP packet from given direction matched // @@ -58,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er } for i := 0; i < n; i++ { - if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { + if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) { bufs = append(bufs[:i], bufs[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...) n-- @@ -82,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { filteredBufs := make([][]byte, 0, len(bufs)) dropped := 0 for _, buf := range bufs { - if !filter.DropIncoming(buf[offset:]) { + if !filter.DropIncoming(buf[offset:], len(buf)) { filteredBufs = append(filteredBufs, buf) dropped++ } diff --git a/client/iface/device/device_filter_test.go b/client/iface/device/device_filter_test.go index d3278b918..c90269e82 100644 --- a/client/iface/device/device_filter_test.go +++ b/client/iface/device/device_filter_test.go @@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) { tun.EXPECT().Write(mockBufs, 0).Return(0, nil) filter := mocks.NewMockPacketFilter(ctrl) - filter.EXPECT().DropIncoming(gomock.Any()).Return(true) + filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true) wrapped := newDeviceFilter(tun) wrapped.filter = filter @@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) { return 1, nil }) filter := mocks.NewMockPacketFilter(ctrl) - filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) + filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true) wrapped := newDeviceFilter(tun) wrapped.filter = filter diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index f00024e38..faac55d68 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -50,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 } // DropIncoming mocks base method. -func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { +func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropIncoming", arg0) + ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } // DropIncoming indicates an expected call of DropIncoming. -func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { +func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1) } // DropOutgoing mocks base method. -func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { +func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DropOutgoing", arg0) + ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } // DropOutgoing indicates an expected call of DropOutgoing. -func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { +func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1) } // RemovePacketHook mocks base method. diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 7c75f6bed..80db2561b 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -458,7 +458,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } packetfilter := pfmock.NewMockPacketFilter(ctrl) - packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() + packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().SetNetwork(ipNet) diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index a0f2eb95e..f57fca1c9 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -2,6 +2,7 @@ package types import ( "net/netip" + "strconv" "time" "github.com/google/uuid" @@ -27,8 +28,10 @@ func (p Protocol) String() string { return "TCP" case 17: return "UDP" + case 132: + return "SCTP" default: - return "unknown" + return strconv.FormatUint(uint64(p), 10) } }