Add stats collection

This commit is contained in:
Viktor Liu 2024-12-23 15:59:55 +01:00
parent ad9f044aad
commit f57bc604a8
19 changed files with 630 additions and 125 deletions

View File

@ -215,6 +215,11 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@ -100,6 +100,9 @@ type Manager interface {
// Flush the changes to firewall controller
Flush() error
// CollectStats returns the statistics of the firewall manager
CollectStats() []*FlowStats
}
func GenKey(format string, pair RouterPair) string {

View File

@ -0,0 +1,107 @@
package manager
import (
"encoding/json"
"net"
"slices"
"strconv"
"sync/atomic"
"time"
)
const (
DirectionInbound Direction = 0
DirectionOutbound Direction = 1
)
type Direction uint8
func (d Direction) String() string {
switch d {
case DirectionInbound:
return "inbound"
case DirectionOutbound:
return "outbound"
default:
return "unknown"
}
}
// FlowStats tracks statistics for an individual connection
type FlowStats struct {
StartTime time.Time
LastSeen time.Time
BytesIn atomic.Uint64
BytesOut atomic.Uint64
PacketsIn atomic.Uint64
PacketsOut atomic.Uint64
Protocol uint8
Direction Direction
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
}
func (f *FlowStats) Clone() *FlowStats {
flowCopy := FlowStats{
StartTime: f.StartTime,
LastSeen: f.LastSeen,
Protocol: f.Protocol,
Direction: f.Direction,
SourceIP: slices.Clone(f.SourceIP),
DestIP: slices.Clone(f.DestIP),
SourcePort: f.SourcePort,
DestPort: f.DestPort,
}
flowCopy.BytesIn.Store(f.BytesIn.Load())
flowCopy.BytesOut.Store(f.BytesOut.Load())
flowCopy.PacketsIn.Store(f.PacketsIn.Load())
flowCopy.PacketsOut.Store(f.PacketsOut.Load())
return &flowCopy
}
// MarshalJSON implements json.Marshaler interface
func (f *FlowStats) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
StartTime time.Time
LastSeen time.Time
BytesIn uint64
BytesOut uint64
PacketsIn uint64
PacketsOut uint64
Protocol Protocol
Direction string
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
}{
StartTime: f.StartTime,
LastSeen: f.LastSeen,
BytesIn: f.BytesIn.Load(),
BytesOut: f.BytesOut.Load(),
PacketsIn: f.PacketsIn.Load(),
PacketsOut: f.PacketsOut.Load(),
Protocol: protoFromInt(f.Protocol),
Direction: f.Direction.String(),
SourceIP: f.SourceIP,
DestIP: f.DestIP,
SourcePort: f.SourcePort,
DestPort: f.DestPort,
})
}
func protoFromInt(p uint8) Protocol {
switch p {
case 6:
return ProtocolTCP
case 17:
return ProtocolUDP
case 1:
return ProtocolICMP
default:
return Protocol(strconv.Itoa(int(p)))
}
}

View File

@ -323,6 +323,11 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@ -17,17 +17,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
}
if m.nativeFirewall != nil {

View File

@ -29,17 +29,17 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
}
if !isWindowsFirewallReachable() {

View File

@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) {
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
// Generate different IPs
@ -79,17 +79,17 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, nil)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, nil)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
// Generate different IPs
@ -104,11 +104,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, nil)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), nil)
}
}
})

View File

@ -6,6 +6,8 @@ import (
"time"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
@ -39,10 +41,11 @@ type ICMPTracker struct {
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
stats *Stats
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
@ -53,6 +56,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
@ -60,7 +64,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, packetData []byte) {
key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
@ -83,14 +87,22 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(1, srcIP, dstIP, 0, 0, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
if t.stats != nil {
key := makeConnKey(srcIP, dstIP, 0, 0)
t.stats.TrackPacket(1, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8, packetData []byte) bool {
switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
@ -115,6 +127,11 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
return false
}
if t.stats != nil {
key := makeConnKey(srcIP, dstIP, 0, 0)
t.stats.TrackPacket(1, false, uint64(len(packetData)), true, key)
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&

View File

@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -15,12 +15,12 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), nil)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0, nil)
}
})
}

View File

@ -0,0 +1,172 @@
package conntrack
import (
"net"
"slices"
"sync"
"sync/atomic"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
// Stats represents connection tracking statistics
type Stats struct {
TotalConnsCreated atomic.Uint64
TotalConnsTimedOut atomic.Uint64
TotalPacketsDropped atomic.Uint64
ActiveConns atomic.Int64
TCPConns atomic.Int64
UDPConns atomic.Int64
ICMPConns atomic.Int64
TCPStateStats struct {
SynReceived atomic.Uint64
Established atomic.Uint64
FinWait atomic.Uint64
TimeWait atomic.Uint64
InvalidStates atomic.Uint64
}
PacketStats struct {
TCPPackets atomic.Uint64
UDPPackets atomic.Uint64
ICMPPackets atomic.Uint64
}
flowMutex sync.RWMutex
flows map[ConnKey]*fw.FlowStats
}
// NewStats creates a new Stats instance
func NewStats() *Stats {
return &Stats{
flows: make(map[ConnKey]*fw.FlowStats),
}
}
// TrackNewConnection records a new connection
func (s *Stats) TrackNewConnection(proto uint8, srcIP net.IP, dstIP net.IP, srcPort, dstPort uint16, direction fw.Direction) {
s.TotalConnsCreated.Add(1)
s.ActiveConns.Add(1)
switch proto {
case 6: // TCP
s.TCPConns.Add(1)
case 17: // UDP
s.UDPConns.Add(1)
case 1: // ICMP
s.ICMPConns.Add(1)
}
flow := &fw.FlowStats{
StartTime: time.Now(),
LastSeen: time.Now(),
Protocol: proto,
Direction: direction,
SourceIP: slices.Clone(srcIP),
DestIP: slices.Clone(dstIP),
SourcePort: srcPort,
DestPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
s.flowMutex.Lock()
s.flows[key] = flow
s.flowMutex.Unlock()
}
// TrackConnectionClosed records a connection closure
func (s *Stats) TrackConnectionClosed(proto uint8, timedOut bool, key ConnKey) {
s.ActiveConns.Add(-1)
if timedOut {
s.TotalConnsTimedOut.Add(1)
}
switch proto {
case 6: // TCP
s.TCPConns.Add(-1)
case 17: // UDP
s.UDPConns.Add(-1)
case 1: // ICMP
s.ICMPConns.Add(-1)
}
s.flowMutex.Lock()
delete(s.flows, key)
s.flowMutex.Unlock()
}
// TrackPacket records packet statistics
func (s *Stats) TrackPacket(proto uint8, dropped bool, bytes uint64, isInbound bool, key ConnKey) {
if dropped {
s.TotalPacketsDropped.Add(1)
return
}
switch proto {
case 6: // TCP
s.PacketStats.TCPPackets.Add(1)
case 17: // UDP
s.PacketStats.UDPPackets.Add(1)
case 1: // ICMP
s.PacketStats.ICMPPackets.Add(1)
}
s.flowMutex.RLock()
if flow, exists := s.flows[key]; exists {
if isInbound {
flow.BytesIn.Add(bytes)
flow.PacketsIn.Add(1)
} else {
flow.BytesOut.Add(bytes)
flow.PacketsOut.Add(1)
}
flow.LastSeen = time.Now()
}
s.flowMutex.RUnlock()
}
// TrackTCPState updates TCP state statistics
func (s *Stats) TrackTCPState(newState TCPState) {
switch newState {
case TCPStateSynReceived:
s.TCPStateStats.SynReceived.Add(1)
case TCPStateEstablished:
s.TCPStateStats.Established.Add(1)
case TCPStateFinWait1, TCPStateFinWait2:
s.TCPStateStats.FinWait.Add(1)
case TCPStateTimeWait:
s.TCPStateStats.TimeWait.Add(1)
default:
s.TCPStateStats.InvalidStates.Add(1)
}
}
// GetFlowSnapshot returns a copy of current flow statistics if enabled
func (s *Stats) GetFlowSnapshot() []*fw.FlowStats {
s.flowMutex.RLock()
defer s.flowMutex.RUnlock()
snapshot := make([]*fw.FlowStats, 0, len(s.flows))
for _, flow := range s.flows {
snapshot = append(snapshot, flow.Clone())
}
return snapshot
}
// CleanupFlows removes flow entries older than the specified duration if enabled
func (s *Stats) CleanupFlows(maxAge time.Duration) {
threshold := time.Now().Add(-maxAge)
s.flowMutex.Lock()
defer s.flowMutex.Unlock()
for key, flow := range s.flows {
if flow.LastSeen.Before(threshold) {
delete(s.flows, key)
}
}
}

View File

@ -6,6 +6,8 @@ import (
"net"
"sync"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
@ -72,16 +74,18 @@ type TCPTracker struct {
done chan struct{}
timeout time.Duration
ipPool *PreallocatedIPs
stats *Stats
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker {
func NewTCPTracker(timeout time.Duration, stats *Stats) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
@ -89,15 +93,13 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
@ -115,18 +117,30 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(6, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
oldState := conn.State
t.updateState(conn, flags, true)
if oldState != conn.State && t.stats != nil {
t.stats.TrackTCPState(conn.State)
}
conn.Unlock()
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) bool {
if !isValidFlagCombination(flags) {
return false
}
@ -156,6 +170,11 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
t.connections[key] = conn
}
t.mutex.Unlock()
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
}
return true
}
@ -169,6 +188,10 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
if t.stats != nil {
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
}
// Handle RST packets
if flags&TCPRst != 0 {
conn.Lock()

View File

@ -9,7 +9,7 @@ import (
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, nil)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
},
},
{
@ -122,11 +122,11 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
require.True(t, valid, "RST should be allowed for established connection")
// Verify connection is closed
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
t.Helper()
require.False(t, valid, "Data should be blocked after RST")
@ -141,13 +141,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@ -157,7 +157,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout)
tracker = NewTCPTracker(DefaultTCPTimeout, nil)
tt.test(t)
})
}
@ -165,7 +165,7 @@ func TestTCPStateMachine(t *testing.T) {
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@ -184,12 +184,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
},
wantValid: true,
desc: "Should accept RST for established connection",
@ -198,7 +198,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
},
wantValid: false,
desc: "Should reject RST without connection",
@ -226,17 +226,17 @@ func TestRSTHandling(t *testing.T) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -244,12 +244,12 @@ func BenchmarkTCPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -257,17 +257,17 @@ func BenchmarkTCPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, nil)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -277,9 +277,9 @@ func BenchmarkTCPTracker(b *testing.B) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, nil)
}
i++
}
@ -290,14 +290,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
}
// Wait for connections to expire

View File

@ -4,6 +4,8 @@ import (
"net"
"sync"
"time"
fw "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
@ -26,10 +28,11 @@ type UDPTracker struct {
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
stats *Stats
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
@ -40,6 +43,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
stats: stats,
}
go tracker.cleanupRoutine()
@ -47,7 +51,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
@ -70,14 +74,21 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
if t.stats != nil {
t.stats.TrackNewConnection(17, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
}
}
t.mutex.Unlock()
if t.stats != nil {
t.stats.TrackPacket(17, false, uint64(len(packetData)), false, key)
}
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
@ -92,6 +103,10 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
if t.stats != nil {
t.stats.TrackPacket(17, false, uint64(len(packetData)), true, key)
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&

View File

@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout)
tracker := NewUDPTracker(tt.timeout, nil)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second)
tracker := NewUDPTracker(1*time.Second, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@ -72,7 +72,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
tests := []struct {
name string
@ -144,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, nil)
assert.Equal(t, tt.want, got)
})
}
@ -189,7 +189,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, nil)
}
// Verify initial connections
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -219,12 +219,12 @@ func BenchmarkUDPTracker(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, nil)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -232,12 +232,12 @@ func BenchmarkUDPTracker(b *testing.B) {
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), nil)
}
})
}

View File

@ -22,7 +22,10 @@ import (
const layerTypeAll = 0
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
const (
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
EnvEnableStats = "NB_ENABLE_CONNTRACK_STATS"
)
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
@ -52,6 +55,9 @@ type Manager struct {
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
statsEnabled bool
stats *conntrack.Stats
}
// decoder for packages
@ -84,6 +90,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
func create(iface IFaceMapper) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
enableStats, _ := strconv.ParseBool(os.Getenv(EnvEnableStats))
m := &Manager{
decoders: sync.Pool{
@ -103,15 +110,21 @@ func create(iface IFaceMapper) (*Manager, error) {
incomingRules: make(map[string]RuleSet),
wgIface: iface,
stateful: !disableConntrack,
statsEnabled: enableStats,
}
if enableStats {
m.stats = conntrack.NewStats()
log.Info("connection tracking statistics enabled")
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.stats)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.stats)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.stats)
}
if err := iface.SetFilter(m); err != nil {
@ -304,7 +317,10 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.trackUDPOutbound(d, srcIP, dstIP)
m.udpTracker.TrackOutbound(srcIP, dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
packetData)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
@ -313,9 +329,16 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
m.tcpTracker.TrackOutbound(srcIP, dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
packetData)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
m.icmpTracker.TrackOutbound(srcIP, dstIP,
d.icmp4.Id,
d.icmp4.Seq,
packetData)
}
}
@ -333,17 +356,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
}
}
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
@ -367,15 +379,6 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
m.udpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
@ -389,17 +392,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
)
}
}
// dropFilter implements filtering logic for incoming packets
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
m.mutex.RLock()
@ -423,7 +415,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
}
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, packetData) {
return false
}
@ -447,7 +439,7 @@ func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
@ -456,6 +448,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
packetData,
)
case layers.LayerTypeUDP:
@ -464,6 +457,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
packetData,
)
case layers.LayerTypeICMPv4:
@ -473,6 +467,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
packetData,
)
// TODO: ICMPv6
@ -612,3 +607,11 @@ func (m *Manager) RemovePacketHook(hookID string) error {
}
return fmt.Errorf("hook with given id not found")
}
// CollectStats returns connection tracking statistics
func (m *Manager) CollectStats() []*firewall.FlowStats {
if m.stats == nil {
return nil
}
return m.stats.GetFlowSnapshot()
}

View File

@ -5,6 +5,7 @@ import (
"math/rand"
"net"
"os"
"strconv"
"strings"
"testing"
@ -965,6 +966,114 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}
}
func BenchmarkFirewallStats(b *testing.B) {
scenarios := []struct {
name string
stats bool
longLived bool
conns int
}{
{"nostats_short_100", false, false, 100},
{"stats_short_100", true, false, 100},
{"nostats_long_100", false, true, 100},
{"stats_long_100", true, true, 100},
{"nostats_short_1000", false, false, 1000},
{"stats_short_1000", true, false, 1000},
{"nostats_long_1000", false, true, 1000},
{"stats_long_1000", true, true, 1000},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
b.Setenv(EnvEnableStats, strconv.FormatBool(sc.stats))
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Generate test IPs
srcIPs := make([]net.IP, sc.conns)
dstIPs := make([]net.IP, sc.conns)
for i := 0; i < sc.conns; i++ {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
// Pre-generate packets
inPackets := make([][]byte, sc.conns)
outPackets := make([][]byte, sc.conns)
for i := 0; i < sc.conns; i++ {
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
if sc.longLived {
// Establish connection
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
synAck := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(syn)
manager.dropFilter(synAck, manager.incomingRules)
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
connIdx := i % sc.conns
if !sc.longLived {
// New connection each time
syn := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPSyn))
synAck := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
80, uint16(1024+connIdx), uint16(conntrack.TCPSyn|conntrack.TCPAck))
ack := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(syn)
manager.dropFilter(synAck, manager.incomingRules)
manager.processOutgoingHooks(ack)
}
// Data transfer
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
if !sc.longLived {
// Tear down
finClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPFin|conntrack.TCPAck))
ackServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
80, uint16(1024+connIdx), uint16(conntrack.TCPAck))
finServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
80, uint16(1024+connIdx), uint16(conntrack.TCPFin|conntrack.TCPAck))
ackClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
uint16(1024+connIdx), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(finClient)
manager.dropFilter(ackServer, manager.incomingRules)
manager.dropFilter(finServer, manager.incomingRules)
manager.processOutgoingHooks(ackClient)
}
}
})
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()

View File

@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{

View File

@ -23,7 +23,7 @@ import (
"google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
firewallmanager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
@ -158,7 +158,7 @@ type Engine struct {
statusRecorder *peer.Status
firewall manager.Manager
firewall firewallmanager.Manager
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
@ -1576,6 +1576,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil
}
// GetFirewallStats returns the firewall stats
func (e *Engine) GetFirewallStats() []*firewallmanager.FlowStats {
if e.firewall != nil {
return e.firewall.CollectStats()
}
return nil
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
if !enabled {

View File

@ -44,6 +44,7 @@ iptables.txt: Anonymized iptables rules with packet counters, if --system-info f
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
firewall_stats.json: Anonymized firewall statistics of the NetBird client.
state.json: Anonymized client state dump containing netbird states.
@ -139,10 +140,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
s.mutex.Lock()
defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return nil, fmt.Errorf("create zip file: %w", err)
@ -202,6 +199,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
return fmt.Errorf("add network map: %w", err)
}
if err := s.addFirewallStats(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall stats to debug bundle: %v", err)
}
if err := s.addStateFile(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add state file to debug bundle: %v", err)
}
@ -356,6 +357,43 @@ func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonym
return nil
}
func (s *Server) addFirewallStats(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if s.connectClient == nil || s.connectClient.Engine() == nil {
return nil
}
stats := s.connectClient.Engine().GetFirewallStats()
if stats == nil {
return nil
}
if req.GetAnonymize() {
for _, stat := range stats {
if stat.SourceIP != nil {
if ip, ok := netip.AddrFromSlice(stat.SourceIP); ok {
stat.SourceIP = anonymizer.AnonymizeIP(ip).AsSlice()
}
}
if stat.DestIP != nil {
if ip, ok := netip.AddrFromSlice(stat.DestIP); ok {
stat.DestIP = anonymizer.AnonymizeIP(ip).AsSlice()
}
}
}
}
jsonBytes, err := json.MarshalIndent(stats, "", " ")
if err != nil {
return fmt.Errorf("marshal firewall stats: %w", err)
}
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "firewall_stats.json"); err != nil {
return fmt.Errorf("add firewall stats to zip: %w", err)
}
return nil
}
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
path := statemanager.GetDefaultStatePath()
if path == "" {