From f57bc604a8d105bdf8910fe907cbdb9512d0f70e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 23 Dec 2024 15:59:55 +0100 Subject: [PATCH] Add stats collection --- client/firewall/iptables/manager_linux.go | 5 + client/firewall/manager/firewall.go | 3 + client/firewall/manager/stats.go | 107 +++++++++++ client/firewall/nftables/manager_linux.go | 5 + client/firewall/uspfilter/allow_netbird.go | 6 +- .../uspfilter/allow_netbird_windows.go | 6 +- .../uspfilter/conntrack/common_test.go | 12 +- client/firewall/uspfilter/conntrack/icmp.go | 23 ++- .../firewall/uspfilter/conntrack/icmp_test.go | 10 +- client/firewall/uspfilter/conntrack/stats.go | 172 ++++++++++++++++++ client/firewall/uspfilter/conntrack/tcp.go | 35 +++- .../firewall/uspfilter/conntrack/tcp_test.go | 72 ++++---- client/firewall/uspfilter/conntrack/udp.go | 21 ++- .../firewall/uspfilter/conntrack/udp_test.go | 24 +-- client/firewall/uspfilter/uspfilter.go | 83 +++++---- .../uspfilter/uspfilter_bench_test.go | 109 +++++++++++ client/firewall/uspfilter/uspfilter_test.go | 4 +- client/internal/engine.go | 12 +- client/server/debug.go | 46 ++++- 19 files changed, 630 insertions(+), 125 deletions(-) create mode 100644 client/firewall/manager/stats.go create mode 100644 client/firewall/uspfilter/conntrack/stats.go diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 0e1e5836f..deb83d8c4 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -215,6 +215,11 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } +// CollectStats returns connection tracking statistics +func (m *Manager) CollectStats() []*firewall.FlowStats { + return nil +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 9391b47ec..59b1d9cfa 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // Flush the changes to firewall controller Flush() error + + // CollectStats returns the statistics of the firewall manager + CollectStats() []*FlowStats } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/manager/stats.go b/client/firewall/manager/stats.go new file mode 100644 index 000000000..d10f08d5e --- /dev/null +++ b/client/firewall/manager/stats.go @@ -0,0 +1,107 @@ +package manager + +import ( + "encoding/json" + "net" + "slices" + "strconv" + "sync/atomic" + "time" +) + +const ( + DirectionInbound Direction = 0 + DirectionOutbound Direction = 1 +) + +type Direction uint8 + +func (d Direction) String() string { + switch d { + case DirectionInbound: + return "inbound" + case DirectionOutbound: + return "outbound" + default: + return "unknown" + } +} + +// FlowStats tracks statistics for an individual connection +type FlowStats struct { + StartTime time.Time + LastSeen time.Time + BytesIn atomic.Uint64 + BytesOut atomic.Uint64 + PacketsIn atomic.Uint64 + PacketsOut atomic.Uint64 + Protocol uint8 + Direction Direction + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 +} + +func (f *FlowStats) Clone() *FlowStats { + flowCopy := FlowStats{ + StartTime: f.StartTime, + LastSeen: f.LastSeen, + Protocol: f.Protocol, + Direction: f.Direction, + SourceIP: slices.Clone(f.SourceIP), + DestIP: slices.Clone(f.DestIP), + SourcePort: f.SourcePort, + DestPort: f.DestPort, + } + flowCopy.BytesIn.Store(f.BytesIn.Load()) + flowCopy.BytesOut.Store(f.BytesOut.Load()) + flowCopy.PacketsIn.Store(f.PacketsIn.Load()) + flowCopy.PacketsOut.Store(f.PacketsOut.Load()) + + return &flowCopy +} + +// MarshalJSON implements json.Marshaler interface +func (f *FlowStats) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + StartTime time.Time + LastSeen time.Time + BytesIn uint64 + BytesOut uint64 + PacketsIn uint64 + PacketsOut uint64 + Protocol Protocol + Direction string + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + }{ + StartTime: f.StartTime, + LastSeen: f.LastSeen, + BytesIn: f.BytesIn.Load(), + BytesOut: f.BytesOut.Load(), + PacketsIn: f.PacketsIn.Load(), + PacketsOut: f.PacketsOut.Load(), + Protocol: protoFromInt(f.Protocol), + Direction: f.Direction.String(), + SourceIP: f.SourceIP, + DestIP: f.DestIP, + SourcePort: f.SourcePort, + DestPort: f.DestPort, + }) +} + +func protoFromInt(p uint8) Protocol { + switch p { + case 6: + return ProtocolTCP + case 17: + return ProtocolUDP + case 1: + return ProtocolICMP + default: + return Protocol(strconv.Itoa(int(p))) + } +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 8e1aa0d80..41cba52c2 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -323,6 +323,11 @@ func (m *Manager) Flush() error { return m.aclManager.Flush() } +// CollectStats returns connection tracking statistics +func (m *Manager) CollectStats() []*firewall.FlowStats { + return nil +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cc0792255..befd8be1d 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -17,17 +17,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil) } if m.nativeFirewall != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 0d55d6268..c04ec39f5 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -29,17 +29,17 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil) } if !isWindowsFirewallReachable() { diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 72d006def..77c7ae89b 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) { // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() // Generate different IPs @@ -79,17 +79,17 @@ func BenchmarkMemoryPressure(b *testing.B) { for i := 0; i < b.N; i++ { srcIdx := i % len(srcIPs) dstIdx := (i + 1) % len(dstIPs) - tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, nil) // Simulate some valid inbound packets if i%3 == 0 { - tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, nil) } } }) b.Run("UDPHighLoad", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() // Generate different IPs @@ -104,11 +104,11 @@ func BenchmarkMemoryPressure(b *testing.B) { for i := 0; i < b.N; i++ { srcIdx := i % len(srcIPs) dstIdx := (i + 1) % len(dstIPs) - tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, nil) // Simulate some valid inbound packets if i%3 == 0 { - tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), nil) } } }) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index e0a971678..5d7fb1aef 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -6,6 +6,8 @@ import ( "time" "github.com/google/gopacket/layers" + + fw "github.com/netbirdio/netbird/client/firewall/manager" ) const ( @@ -39,10 +41,11 @@ type ICMPTracker struct { mutex sync.RWMutex done chan struct{} ipPool *PreallocatedIPs + stats *Stats } // NewICMPTracker creates a new ICMP connection tracker -func NewICMPTracker(timeout time.Duration) *ICMPTracker { +func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker { if timeout == 0 { timeout = DefaultICMPTimeout } @@ -53,6 +56,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { cleanupTicker: time.NewTicker(ICMPCleanupInterval), done: make(chan struct{}), ipPool: NewPreallocatedIPs(), + stats: stats, } go tracker.cleanupRoutine() @@ -60,7 +64,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { } // TrackOutbound records an outbound ICMP Echo Request -func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, packetData []byte) { key := makeICMPKey(srcIP, dstIP, id, seq) now := time.Now().UnixNano() @@ -83,14 +87,22 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u conn.lastSeen.Store(now) conn.established.Store(true) t.connections[key] = conn + + if t.stats != nil { + t.stats.TrackNewConnection(1, srcIP, dstIP, 0, 0, fw.DirectionOutbound) + } } t.mutex.Unlock() + if t.stats != nil { + key := makeConnKey(srcIP, dstIP, 0, 0) + t.stats.TrackPacket(1, false, uint64(len(packetData)), false, key) + } conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request -func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8, packetData []byte) bool { switch icmpType { case uint8(layers.ICMPv4TypeDestinationUnreachable), uint8(layers.ICMPv4TypeTimeExceeded): @@ -115,6 +127,11 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq return false } + if t.stats != nil { + key := makeConnKey(srcIP, dstIP, 0, 0) + t.stats.TrackPacket(1, false, uint64(len(packetData)), true, key) + } + return conn.IsEstablished() && ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 21176e719..eb727af99 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -7,7 +7,7 @@ import ( func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -15,12 +15,12 @@ func BenchmarkICMPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), nil) } }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), nil) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0, nil) } }) } diff --git a/client/firewall/uspfilter/conntrack/stats.go b/client/firewall/uspfilter/conntrack/stats.go new file mode 100644 index 000000000..61c06b9af --- /dev/null +++ b/client/firewall/uspfilter/conntrack/stats.go @@ -0,0 +1,172 @@ +package conntrack + +import ( + "net" + "slices" + "sync" + "sync/atomic" + "time" + + fw "github.com/netbirdio/netbird/client/firewall/manager" +) + +// Stats represents connection tracking statistics +type Stats struct { + TotalConnsCreated atomic.Uint64 + TotalConnsTimedOut atomic.Uint64 + TotalPacketsDropped atomic.Uint64 + ActiveConns atomic.Int64 + + TCPConns atomic.Int64 + UDPConns atomic.Int64 + ICMPConns atomic.Int64 + + TCPStateStats struct { + SynReceived atomic.Uint64 + Established atomic.Uint64 + FinWait atomic.Uint64 + TimeWait atomic.Uint64 + InvalidStates atomic.Uint64 + } + + PacketStats struct { + TCPPackets atomic.Uint64 + UDPPackets atomic.Uint64 + ICMPPackets atomic.Uint64 + } + + flowMutex sync.RWMutex + flows map[ConnKey]*fw.FlowStats +} + +// NewStats creates a new Stats instance +func NewStats() *Stats { + return &Stats{ + flows: make(map[ConnKey]*fw.FlowStats), + } +} + +// TrackNewConnection records a new connection +func (s *Stats) TrackNewConnection(proto uint8, srcIP net.IP, dstIP net.IP, srcPort, dstPort uint16, direction fw.Direction) { + s.TotalConnsCreated.Add(1) + s.ActiveConns.Add(1) + + switch proto { + case 6: // TCP + s.TCPConns.Add(1) + case 17: // UDP + s.UDPConns.Add(1) + case 1: // ICMP + s.ICMPConns.Add(1) + } + + flow := &fw.FlowStats{ + StartTime: time.Now(), + LastSeen: time.Now(), + Protocol: proto, + Direction: direction, + SourceIP: slices.Clone(srcIP), + DestIP: slices.Clone(dstIP), + SourcePort: srcPort, + DestPort: dstPort, + } + + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + s.flowMutex.Lock() + s.flows[key] = flow + s.flowMutex.Unlock() +} + +// TrackConnectionClosed records a connection closure +func (s *Stats) TrackConnectionClosed(proto uint8, timedOut bool, key ConnKey) { + s.ActiveConns.Add(-1) + + if timedOut { + s.TotalConnsTimedOut.Add(1) + } + + switch proto { + case 6: // TCP + s.TCPConns.Add(-1) + case 17: // UDP + s.UDPConns.Add(-1) + case 1: // ICMP + s.ICMPConns.Add(-1) + } + + s.flowMutex.Lock() + delete(s.flows, key) + s.flowMutex.Unlock() +} + +// TrackPacket records packet statistics +func (s *Stats) TrackPacket(proto uint8, dropped bool, bytes uint64, isInbound bool, key ConnKey) { + if dropped { + s.TotalPacketsDropped.Add(1) + return + } + + switch proto { + case 6: // TCP + s.PacketStats.TCPPackets.Add(1) + case 17: // UDP + s.PacketStats.UDPPackets.Add(1) + case 1: // ICMP + s.PacketStats.ICMPPackets.Add(1) + } + + s.flowMutex.RLock() + if flow, exists := s.flows[key]; exists { + if isInbound { + flow.BytesIn.Add(bytes) + flow.PacketsIn.Add(1) + } else { + flow.BytesOut.Add(bytes) + flow.PacketsOut.Add(1) + } + flow.LastSeen = time.Now() + } + s.flowMutex.RUnlock() +} + +// TrackTCPState updates TCP state statistics +func (s *Stats) TrackTCPState(newState TCPState) { + switch newState { + case TCPStateSynReceived: + s.TCPStateStats.SynReceived.Add(1) + case TCPStateEstablished: + s.TCPStateStats.Established.Add(1) + case TCPStateFinWait1, TCPStateFinWait2: + s.TCPStateStats.FinWait.Add(1) + case TCPStateTimeWait: + s.TCPStateStats.TimeWait.Add(1) + default: + s.TCPStateStats.InvalidStates.Add(1) + } +} + +// GetFlowSnapshot returns a copy of current flow statistics if enabled +func (s *Stats) GetFlowSnapshot() []*fw.FlowStats { + s.flowMutex.RLock() + defer s.flowMutex.RUnlock() + + snapshot := make([]*fw.FlowStats, 0, len(s.flows)) + for _, flow := range s.flows { + snapshot = append(snapshot, flow.Clone()) + } + return snapshot +} + +// CleanupFlows removes flow entries older than the specified duration if enabled +func (s *Stats) CleanupFlows(maxAge time.Duration) { + threshold := time.Now().Add(-maxAge) + + s.flowMutex.Lock() + defer s.flowMutex.Unlock() + + for key, flow := range s.flows { + if flow.LastSeen.Before(threshold) { + delete(s.flows, key) + } + } +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index e8d20f41c..5506c0597 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -6,6 +6,8 @@ import ( "net" "sync" "time" + + fw "github.com/netbirdio/netbird/client/firewall/manager" ) const ( @@ -72,16 +74,18 @@ type TCPTracker struct { done chan struct{} timeout time.Duration ipPool *PreallocatedIPs + stats *Stats } // NewTCPTracker creates a new TCP connection tracker -func NewTCPTracker(timeout time.Duration) *TCPTracker { +func NewTCPTracker(timeout time.Duration, stats *Stats) *TCPTracker { tracker := &TCPTracker{ connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), timeout: timeout, ipPool: NewPreallocatedIPs(), + stats: stats, } go tracker.cleanupRoutine() @@ -89,15 +93,13 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker { } // TrackOutbound processes an outbound TCP packet and updates connection state -func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { - // Create key before lock +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] if !exists { - // Use preallocated IPs srcIPCopy := t.ipPool.Get() dstIPCopy := t.ipPool.Get() copyIP(srcIPCopy, srcIP) @@ -115,18 +117,30 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.lastSeen.Store(now) conn.established.Store(false) t.connections[key] = conn + + if t.stats != nil { + t.stats.TrackNewConnection(6, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound) + } } t.mutex.Unlock() - // Lock individual connection for state update conn.Lock() + oldState := conn.State t.updateState(conn, flags, true) + if oldState != conn.State && t.stats != nil { + t.stats.TrackTCPState(conn.State) + } conn.Unlock() + + if t.stats != nil { + t.stats.TrackPacket(6, false, uint64(len(packetData)), false, key) + } conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound TCP packet matches a tracked connection -func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { +func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) bool { + if !isValidFlagCombination(flags) { return false } @@ -156,6 +170,11 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, t.connections[key] = conn } t.mutex.Unlock() + + if t.stats != nil { + t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key) + } + return true } @@ -169,6 +188,10 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } + if t.stats != nil { + t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key) + } + // Handle RST packets if flags&TCPRst != 0 { conn.Lock() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 3933c8889..f3c2994bf 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCPStateMachine(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, nil) require.Equal(t, !tt.wantDrop, isValid, tt.desc) }) } @@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) { t.Helper() // Send initial SYN - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil) // Receive SYN-ACK - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil) require.True(t, valid, "SYN-ACK should be allowed") // Send ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil) // Test data transfer - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil) require.True(t, valid, "Data should be allowed after handshake") }, }, @@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) { establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) // Send FIN - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil) // Receive ACK for FIN - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil) require.True(t, valid, "ACK for FIN should be allowed") // Receive FIN from other side - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil) require.True(t, valid, "FIN should be allowed") // Send final ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil) }, }, { @@ -122,11 +122,11 @@ func TestTCPStateMachine(t *testing.T) { establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) // Receive RST - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil) require.True(t, valid, "RST should be allowed for established connection") // Verify connection is closed - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil) t.Helper() require.False(t, valid, "Data should be blocked after RST") @@ -141,13 +141,13 @@ func TestTCPStateMachine(t *testing.T) { establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) // Both sides send FIN+ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil) require.True(t, valid, "Simultaneous FIN should be allowed") // Both sides send final ACK - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil) require.True(t, valid, "Final ACKs should be allowed") }, }, @@ -157,7 +157,7 @@ func TestTCPStateMachine(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Helper() - tracker = NewTCPTracker(DefaultTCPTimeout) + tracker = NewTCPTracker(DefaultTCPTimeout, nil) tt.test(t) }) } @@ -165,7 +165,7 @@ func TestTCPStateMachine(t *testing.T) { } func TestRSTHandling(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -184,12 +184,12 @@ func TestRSTHandling(t *testing.T) { name: "RST in established", setupState: func() { // Establish connection first - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) - tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil) }, sendRST: func() { - tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil) }, wantValid: true, desc: "Should accept RST for established connection", @@ -198,7 +198,7 @@ func TestRSTHandling(t *testing.T) { name: "RST without connection", setupState: func() {}, sendRST: func() { - tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil) }, wantValid: false, desc: "Should reject RST without connection", @@ -226,17 +226,17 @@ func TestRSTHandling(t *testing.T) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { t.Helper() - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil) - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil) require.True(t, valid, "SYN-ACK should be allowed") - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil) } func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -244,12 +244,12 @@ func BenchmarkTCPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil) } }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -257,17 +257,17 @@ func BenchmarkTCPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, nil) } }) b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -277,9 +277,9 @@ func BenchmarkTCPTracker(b *testing.B) { i := 0 for pb.Next() { if i%2 == 0 { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil) } else { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, nil) } i++ } @@ -290,14 +290,14 @@ func BenchmarkTCPTracker(b *testing.B) { // Benchmark connection cleanup func BenchmarkCleanup(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing defer tracker.Close() // Pre-populate with expired connections srcIP := net.ParseIP("192.168.1.1") dstIP := net.ParseIP("192.168.1.2") for i := 0; i < 10000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil) } // Wait for connections to expire diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a969a4e84..628b99986 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -4,6 +4,8 @@ import ( "net" "sync" "time" + + fw "github.com/netbirdio/netbird/client/firewall/manager" ) const ( @@ -26,10 +28,11 @@ type UDPTracker struct { mutex sync.RWMutex done chan struct{} ipPool *PreallocatedIPs + stats *Stats } // NewUDPTracker creates a new UDP connection tracker -func NewUDPTracker(timeout time.Duration) *UDPTracker { +func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker { if timeout == 0 { timeout = DefaultUDPTimeout } @@ -40,6 +43,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { cleanupTicker: time.NewTicker(UDPCleanupInterval), done: make(chan struct{}), ipPool: NewPreallocatedIPs(), + stats: stats, } go tracker.cleanupRoutine() @@ -47,7 +51,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { } // TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) now := time.Now().UnixNano() @@ -70,14 +74,21 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.lastSeen.Store(now) conn.established.Store(true) t.connections[key] = conn + + if t.stats != nil { + t.stats.TrackNewConnection(17, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound) + } } t.mutex.Unlock() + if t.stats != nil { + t.stats.TrackPacket(17, false, uint64(len(packetData)), false, key) + } conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound packet matches a tracked connection -func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { +func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) bool { key := makeConnKey(dstIP, srcIP, dstPort, srcPort) t.mutex.RLock() @@ -92,6 +103,10 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } + if t.stats != nil { + t.stats.TrackPacket(17, false, uint64(len(packetData)), true, key) + } + return conn.IsEstablished() && ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 671721890..a48a7ccad 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := NewUDPTracker(tt.timeout) + tracker := NewUDPTracker(tt.timeout, nil) assert.NotNil(t, tracker) assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { srcPort := uint16(12345) dstPort := uint16(53) - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil) // Verify connection was tracked key := makeConnKey(srcIP, dstIP, srcPort, dstPort) @@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { } func TestUDPTracker_IsValidInbound(t *testing.T) { - tracker := NewUDPTracker(1 * time.Second) + tracker := NewUDPTracker(1*time.Second, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -72,7 +72,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { dstPort := uint16(53) // Track outbound connection - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil) tests := []struct { name string @@ -144,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { if tt.sleep > 0 { time.Sleep(tt.sleep) } - got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, nil) assert.Equal(t, tt.want, got) }) } @@ -189,7 +189,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { } for _, conn := range connections { - tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, nil) } // Verify initial connections @@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { func BenchmarkUDPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -219,12 +219,12 @@ func BenchmarkUDPTracker(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, nil) } }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, nil) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -232,12 +232,12 @@ func BenchmarkUDPTracker(b *testing.B) { // Pre-populate some connections for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, nil) } b.ResetTimer() for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), nil) } }) } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 24cfd6e96..ae2a768ea 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -22,7 +22,10 @@ import ( const layerTypeAll = 0 -const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" +const ( + EnvDisableConntrack = "NB_DISABLE_CONNTRACK" + EnvEnableStats = "NB_ENABLE_CONNTRACK_STATS" +) var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") @@ -52,6 +55,9 @@ type Manager struct { udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker + + statsEnabled bool + stats *conntrack.Stats } // decoder for packages @@ -84,6 +90,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager func create(iface IFaceMapper) (*Manager, error) { disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + enableStats, _ := strconv.ParseBool(os.Getenv(EnvEnableStats)) m := &Manager{ decoders: sync.Pool{ @@ -103,15 +110,21 @@ func create(iface IFaceMapper) (*Manager, error) { incomingRules: make(map[string]RuleSet), wgIface: iface, stateful: !disableConntrack, + statsEnabled: enableStats, + } + + if enableStats { + m.stats = conntrack.NewStats() + log.Info("connection tracking statistics enabled") } // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") } else { - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.stats) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.stats) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.stats) } if err := iface.SetFilter(m); err != nil { @@ -304,7 +317,10 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { if d.decoded[1] == layers.LayerTypeUDP { // Track UDP state only if enabled if m.stateful { - m.trackUDPOutbound(d, srcIP, dstIP) + m.udpTracker.TrackOutbound(srcIP, dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + packetData) } return m.checkUDPHooks(d, dstIP, packetData) } @@ -313,9 +329,16 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { if m.stateful { switch d.decoded[1] { case layers.LayerTypeTCP: - m.trackTCPOutbound(d, srcIP, dstIP) + m.tcpTracker.TrackOutbound(srcIP, dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + getTCPFlags(&d.tcp), + packetData) case layers.LayerTypeICMPv4: - m.trackICMPOutbound(d, srcIP, dstIP) + m.icmpTracker.TrackOutbound(srcIP, dstIP, + d.icmp4.Id, + d.icmp4.Seq, + packetData) } } @@ -333,17 +356,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { } } -func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) { - flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.tcp.SrcPort), - uint16(d.tcp.DstPort), - flags, - ) -} - func getTCPFlags(tcp *layers.TCP) uint8 { var flags uint8 if tcp.SYN { @@ -367,15 +379,6 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { - m.udpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.udp.SrcPort), - uint16(d.udp.DstPort), - ) -} - func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { if rules, exists := m.outgoingRules[ipKey]; exists { @@ -389,17 +392,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo return false } -func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { - if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { - m.icmpTracker.TrackOutbound( - srcIP, - dstIP, - d.icmp4.Id, - d.icmp4.Seq, - ) - } -} - // dropFilter implements filtering logic for incoming packets func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { m.mutex.RLock() @@ -423,7 +415,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { } // Check connection state only if enabled - if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, packetData) { return false } @@ -447,7 +439,7 @@ func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) } -func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { switch d.decoded[1] { case layers.LayerTypeTCP: return m.tcpTracker.IsValidInbound( @@ -456,6 +448,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), getTCPFlags(&d.tcp), + packetData, ) case layers.LayerTypeUDP: @@ -464,6 +457,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), + packetData, ) case layers.LayerTypeICMPv4: @@ -473,6 +467,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode.Type(), + packetData, ) // TODO: ICMPv6 @@ -612,3 +607,11 @@ func (m *Manager) RemovePacketHook(hookID string) error { } return fmt.Errorf("hook with given id not found") } + +// CollectStats returns connection tracking statistics +func (m *Manager) CollectStats() []*firewall.FlowStats { + if m.stats == nil { + return nil + } + return m.stats.GetFlowSnapshot() +} diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 3c661e71c..f6c5e0f61 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -5,6 +5,7 @@ import ( "math/rand" "net" "os" + "strconv" "strings" "testing" @@ -965,6 +966,114 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { } } +func BenchmarkFirewallStats(b *testing.B) { + scenarios := []struct { + name string + stats bool + longLived bool + conns int + }{ + {"nostats_short_100", false, false, 100}, + {"stats_short_100", true, false, 100}, + {"nostats_long_100", false, true, 100}, + {"stats_long_100", true, true, 100}, + {"nostats_short_1000", false, false, 1000}, + {"stats_short_1000", true, false, 1000}, + {"nostats_long_1000", false, true, 1000}, + {"stats_long_1000", true, true, 1000}, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + b.Setenv(EnvEnableStats, strconv.FormatBool(sc.stats)) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Generate test IPs + srcIPs := make([]net.IP, sc.conns) + dstIPs := make([]net.IP, sc.conns) + for i := 0; i < sc.conns; i++ { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + + // Pre-generate packets + inPackets := make([][]byte, sc.conns) + outPackets := make([][]byte, sc.conns) + for i := 0; i < sc.conns; i++ { + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + + if sc.longLived { + // Establish connection + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + synAck := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + + manager.processOutgoingHooks(syn) + manager.dropFilter(synAck, manager.incomingRules) + manager.processOutgoingHooks(ack) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + connIdx := i % sc.conns + + if !sc.longLived { + // New connection each time + syn := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx], + uint16(1024+connIdx), 80, uint16(conntrack.TCPSyn)) + synAck := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx], + 80, uint16(1024+connIdx), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + ack := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx], + uint16(1024+connIdx), 80, uint16(conntrack.TCPAck)) + + manager.processOutgoingHooks(syn) + manager.dropFilter(synAck, manager.incomingRules) + manager.processOutgoingHooks(ack) + } + + // Data transfer + manager.processOutgoingHooks(outPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + + if !sc.longLived { + // Tear down + finClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx], + uint16(1024+connIdx), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)) + ackServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx], + 80, uint16(1024+connIdx), uint16(conntrack.TCPAck)) + finServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx], + 80, uint16(1024+connIdx), uint16(conntrack.TCPFin|conntrack.TCPAck)) + ackClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx], + uint16(1024+connIdx), 80, uint16(conntrack.TCPAck)) + + manager.processOutgoingHooks(finClient) + manager.dropFilter(ackServer, manager.incomingRules) + manager.dropFilter(finServer, manager.incomingRules) + manager.processOutgoingHooks(ackClient) + } + } + }) + } +} + // generateTCPPacketWithFlags creates a TCP packet with specific flags func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { b.Helper() diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d3563e6f2..87d41fae1 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) { Mask: net.CIDRMask(16, 32), } manager.udpTracker.Close() - manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil) defer func() { require.NoError(t, manager.Reset(nil)) }() @@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } manager.udpTracker.Close() // Close the existing tracker - manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil) manager.decoders = sync.Pool{ New: func() any { d := &decoder{ diff --git a/client/internal/engine.go b/client/internal/engine.go index 042d384dc..050e81dc6 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,7 +23,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/firewall" - "github.com/netbirdio/netbird/client/firewall/manager" + firewallmanager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" @@ -158,7 +158,7 @@ type Engine struct { statusRecorder *peer.Status - firewall manager.Manager + firewall firewallmanager.Manager routeManager routemanager.Manager acl acl.Manager dnsForwardMgr *dnsfwd.Manager @@ -1576,6 +1576,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +// GetFirewallStats returns the firewall stats +func (e *Engine) GetFirewallStats() []*firewallmanager.FlowStats { + if e.firewall != nil { + return e.firewall.CollectStats() + } + return nil +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { if !enabled { diff --git a/client/server/debug.go b/client/server/debug.go index 9dfde0367..a33c65279 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -44,6 +44,7 @@ iptables.txt: Anonymized iptables rules with packet counters, if --system-info f nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. +firewall_stats.json: Anonymized firewall statistics of the NetBird client. state.json: Anonymized client state dump containing netbird states. @@ -139,10 +140,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( s.mutex.Lock() defer s.mutex.Unlock() - if s.logFile == "console" { - return nil, fmt.Errorf("log file is set to console, cannot create debug bundle") - } - bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") if err != nil { return nil, fmt.Errorf("create zip file: %w", err) @@ -202,6 +199,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques return fmt.Errorf("add network map: %w", err) } + if err := s.addFirewallStats(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add firewall stats to debug bundle: %v", err) + } + if err := s.addStateFile(req, anonymizer, archive); err != nil { log.Errorf("Failed to add state file to debug bundle: %v", err) } @@ -356,6 +357,43 @@ func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonym return nil } +func (s *Server) addFirewallStats(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + if s.connectClient == nil || s.connectClient.Engine() == nil { + return nil + } + + stats := s.connectClient.Engine().GetFirewallStats() + if stats == nil { + return nil + } + + if req.GetAnonymize() { + for _, stat := range stats { + if stat.SourceIP != nil { + if ip, ok := netip.AddrFromSlice(stat.SourceIP); ok { + stat.SourceIP = anonymizer.AnonymizeIP(ip).AsSlice() + } + } + if stat.DestIP != nil { + if ip, ok := netip.AddrFromSlice(stat.DestIP); ok { + stat.DestIP = anonymizer.AnonymizeIP(ip).AsSlice() + } + } + } + } + + jsonBytes, err := json.MarshalIndent(stats, "", " ") + if err != nil { + return fmt.Errorf("marshal firewall stats: %w", err) + } + + if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "firewall_stats.json"); err != nil { + return fmt.Errorf("add firewall stats to zip: %w", err) + } + + return nil +} + func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { path := statemanager.GetDefaultStatePath() if path == "" {