mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-08 14:29:39 +01:00
Add stats collection
This commit is contained in:
parent
ad9f044aad
commit
f57bc604a8
@ -215,6 +215,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// CollectStats returns connection tracking statistics
|
||||
func (m *Manager) CollectStats() []*firewall.FlowStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
@ -100,6 +100,9 @@ type Manager interface {
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// CollectStats returns the statistics of the firewall manager
|
||||
CollectStats() []*FlowStats
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
107
client/firewall/manager/stats.go
Normal file
107
client/firewall/manager/stats.go
Normal file
@ -0,0 +1,107 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DirectionInbound Direction = 0
|
||||
DirectionOutbound Direction = 1
|
||||
)
|
||||
|
||||
type Direction uint8
|
||||
|
||||
func (d Direction) String() string {
|
||||
switch d {
|
||||
case DirectionInbound:
|
||||
return "inbound"
|
||||
case DirectionOutbound:
|
||||
return "outbound"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// FlowStats tracks statistics for an individual connection
|
||||
type FlowStats struct {
|
||||
StartTime time.Time
|
||||
LastSeen time.Time
|
||||
BytesIn atomic.Uint64
|
||||
BytesOut atomic.Uint64
|
||||
PacketsIn atomic.Uint64
|
||||
PacketsOut atomic.Uint64
|
||||
Protocol uint8
|
||||
Direction Direction
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
}
|
||||
|
||||
func (f *FlowStats) Clone() *FlowStats {
|
||||
flowCopy := FlowStats{
|
||||
StartTime: f.StartTime,
|
||||
LastSeen: f.LastSeen,
|
||||
Protocol: f.Protocol,
|
||||
Direction: f.Direction,
|
||||
SourceIP: slices.Clone(f.SourceIP),
|
||||
DestIP: slices.Clone(f.DestIP),
|
||||
SourcePort: f.SourcePort,
|
||||
DestPort: f.DestPort,
|
||||
}
|
||||
flowCopy.BytesIn.Store(f.BytesIn.Load())
|
||||
flowCopy.BytesOut.Store(f.BytesOut.Load())
|
||||
flowCopy.PacketsIn.Store(f.PacketsIn.Load())
|
||||
flowCopy.PacketsOut.Store(f.PacketsOut.Load())
|
||||
|
||||
return &flowCopy
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler interface
|
||||
func (f *FlowStats) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
StartTime time.Time
|
||||
LastSeen time.Time
|
||||
BytesIn uint64
|
||||
BytesOut uint64
|
||||
PacketsIn uint64
|
||||
PacketsOut uint64
|
||||
Protocol Protocol
|
||||
Direction string
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
}{
|
||||
StartTime: f.StartTime,
|
||||
LastSeen: f.LastSeen,
|
||||
BytesIn: f.BytesIn.Load(),
|
||||
BytesOut: f.BytesOut.Load(),
|
||||
PacketsIn: f.PacketsIn.Load(),
|
||||
PacketsOut: f.PacketsOut.Load(),
|
||||
Protocol: protoFromInt(f.Protocol),
|
||||
Direction: f.Direction.String(),
|
||||
SourceIP: f.SourceIP,
|
||||
DestIP: f.DestIP,
|
||||
SourcePort: f.SourcePort,
|
||||
DestPort: f.DestPort,
|
||||
})
|
||||
}
|
||||
|
||||
func protoFromInt(p uint8) Protocol {
|
||||
switch p {
|
||||
case 6:
|
||||
return ProtocolTCP
|
||||
case 17:
|
||||
return ProtocolUDP
|
||||
case 1:
|
||||
return ProtocolICMP
|
||||
default:
|
||||
return Protocol(strconv.Itoa(int(p)))
|
||||
}
|
||||
}
|
@ -323,6 +323,11 @@ func (m *Manager) Flush() error {
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
// CollectStats returns connection tracking statistics
|
||||
func (m *Manager) CollectStats() []*firewall.FlowStats {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
@ -17,17 +17,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
|
@ -29,17 +29,17 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, nil)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, nil)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, nil)
|
||||
}
|
||||
|
||||
if !isWindowsFirewallReachable() {
|
||||
|
@ -64,7 +64,7 @@ func BenchmarkAtomicOperations(b *testing.B) {
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
@ -79,17 +79,17 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, nil)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
@ -104,11 +104,11 @@ func BenchmarkMemoryPressure(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, nil)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -39,10 +41,11 @@ type ICMPTracker struct {
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
stats *Stats
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, stats *Stats) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
@ -53,6 +56,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
stats: stats,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
@ -60,7 +64,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker {
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, packetData []byte) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
@ -83,14 +87,22 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
t.connections[key] = conn
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackNewConnection(1, srcIP, dstIP, 0, 0, fw.DirectionOutbound)
|
||||
}
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.stats != nil {
|
||||
key := makeConnKey(srcIP, dstIP, 0, 0)
|
||||
t.stats.TrackPacket(1, false, uint64(len(packetData)), false, key)
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8, packetData []byte) bool {
|
||||
switch icmpType {
|
||||
case uint8(layers.ICMPv4TypeDestinationUnreachable),
|
||||
uint8(layers.ICMPv4TypeTimeExceeded):
|
||||
@ -115,6 +127,11 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
||||
return false
|
||||
}
|
||||
|
||||
if t.stats != nil {
|
||||
key := makeConnKey(srcIP, dstIP, 0, 0)
|
||||
t.stats.TrackPacket(1, false, uint64(len(packetData)), true, key)
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -15,12 +15,12 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), nil)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -28,12 +28,12 @@ func BenchmarkICMPTracker(b *testing.B) {
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), nil)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0, nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
172
client/firewall/uspfilter/conntrack/stats.go
Normal file
172
client/firewall/uspfilter/conntrack/stats.go
Normal file
@ -0,0 +1,172 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Stats represents connection tracking statistics
|
||||
type Stats struct {
|
||||
TotalConnsCreated atomic.Uint64
|
||||
TotalConnsTimedOut atomic.Uint64
|
||||
TotalPacketsDropped atomic.Uint64
|
||||
ActiveConns atomic.Int64
|
||||
|
||||
TCPConns atomic.Int64
|
||||
UDPConns atomic.Int64
|
||||
ICMPConns atomic.Int64
|
||||
|
||||
TCPStateStats struct {
|
||||
SynReceived atomic.Uint64
|
||||
Established atomic.Uint64
|
||||
FinWait atomic.Uint64
|
||||
TimeWait atomic.Uint64
|
||||
InvalidStates atomic.Uint64
|
||||
}
|
||||
|
||||
PacketStats struct {
|
||||
TCPPackets atomic.Uint64
|
||||
UDPPackets atomic.Uint64
|
||||
ICMPPackets atomic.Uint64
|
||||
}
|
||||
|
||||
flowMutex sync.RWMutex
|
||||
flows map[ConnKey]*fw.FlowStats
|
||||
}
|
||||
|
||||
// NewStats creates a new Stats instance
|
||||
func NewStats() *Stats {
|
||||
return &Stats{
|
||||
flows: make(map[ConnKey]*fw.FlowStats),
|
||||
}
|
||||
}
|
||||
|
||||
// TrackNewConnection records a new connection
|
||||
func (s *Stats) TrackNewConnection(proto uint8, srcIP net.IP, dstIP net.IP, srcPort, dstPort uint16, direction fw.Direction) {
|
||||
s.TotalConnsCreated.Add(1)
|
||||
s.ActiveConns.Add(1)
|
||||
|
||||
switch proto {
|
||||
case 6: // TCP
|
||||
s.TCPConns.Add(1)
|
||||
case 17: // UDP
|
||||
s.UDPConns.Add(1)
|
||||
case 1: // ICMP
|
||||
s.ICMPConns.Add(1)
|
||||
}
|
||||
|
||||
flow := &fw.FlowStats{
|
||||
StartTime: time.Now(),
|
||||
LastSeen: time.Now(),
|
||||
Protocol: proto,
|
||||
Direction: direction,
|
||||
SourceIP: slices.Clone(srcIP),
|
||||
DestIP: slices.Clone(dstIP),
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
s.flowMutex.Lock()
|
||||
s.flows[key] = flow
|
||||
s.flowMutex.Unlock()
|
||||
}
|
||||
|
||||
// TrackConnectionClosed records a connection closure
|
||||
func (s *Stats) TrackConnectionClosed(proto uint8, timedOut bool, key ConnKey) {
|
||||
s.ActiveConns.Add(-1)
|
||||
|
||||
if timedOut {
|
||||
s.TotalConnsTimedOut.Add(1)
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case 6: // TCP
|
||||
s.TCPConns.Add(-1)
|
||||
case 17: // UDP
|
||||
s.UDPConns.Add(-1)
|
||||
case 1: // ICMP
|
||||
s.ICMPConns.Add(-1)
|
||||
}
|
||||
|
||||
s.flowMutex.Lock()
|
||||
delete(s.flows, key)
|
||||
s.flowMutex.Unlock()
|
||||
}
|
||||
|
||||
// TrackPacket records packet statistics
|
||||
func (s *Stats) TrackPacket(proto uint8, dropped bool, bytes uint64, isInbound bool, key ConnKey) {
|
||||
if dropped {
|
||||
s.TotalPacketsDropped.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case 6: // TCP
|
||||
s.PacketStats.TCPPackets.Add(1)
|
||||
case 17: // UDP
|
||||
s.PacketStats.UDPPackets.Add(1)
|
||||
case 1: // ICMP
|
||||
s.PacketStats.ICMPPackets.Add(1)
|
||||
}
|
||||
|
||||
s.flowMutex.RLock()
|
||||
if flow, exists := s.flows[key]; exists {
|
||||
if isInbound {
|
||||
flow.BytesIn.Add(bytes)
|
||||
flow.PacketsIn.Add(1)
|
||||
} else {
|
||||
flow.BytesOut.Add(bytes)
|
||||
flow.PacketsOut.Add(1)
|
||||
}
|
||||
flow.LastSeen = time.Now()
|
||||
}
|
||||
s.flowMutex.RUnlock()
|
||||
}
|
||||
|
||||
// TrackTCPState updates TCP state statistics
|
||||
func (s *Stats) TrackTCPState(newState TCPState) {
|
||||
switch newState {
|
||||
case TCPStateSynReceived:
|
||||
s.TCPStateStats.SynReceived.Add(1)
|
||||
case TCPStateEstablished:
|
||||
s.TCPStateStats.Established.Add(1)
|
||||
case TCPStateFinWait1, TCPStateFinWait2:
|
||||
s.TCPStateStats.FinWait.Add(1)
|
||||
case TCPStateTimeWait:
|
||||
s.TCPStateStats.TimeWait.Add(1)
|
||||
default:
|
||||
s.TCPStateStats.InvalidStates.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetFlowSnapshot returns a copy of current flow statistics if enabled
|
||||
func (s *Stats) GetFlowSnapshot() []*fw.FlowStats {
|
||||
s.flowMutex.RLock()
|
||||
defer s.flowMutex.RUnlock()
|
||||
|
||||
snapshot := make([]*fw.FlowStats, 0, len(s.flows))
|
||||
for _, flow := range s.flows {
|
||||
snapshot = append(snapshot, flow.Clone())
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
// CleanupFlows removes flow entries older than the specified duration if enabled
|
||||
func (s *Stats) CleanupFlows(maxAge time.Duration) {
|
||||
threshold := time.Now().Add(-maxAge)
|
||||
|
||||
s.flowMutex.Lock()
|
||||
defer s.flowMutex.Unlock()
|
||||
|
||||
for key, flow := range s.flows {
|
||||
if flow.LastSeen.Before(threshold) {
|
||||
delete(s.flows, key)
|
||||
}
|
||||
}
|
||||
}
|
@ -6,6 +6,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -72,16 +74,18 @@ type TCPTracker struct {
|
||||
done chan struct{}
|
||||
timeout time.Duration
|
||||
ipPool *PreallocatedIPs
|
||||
stats *Stats
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
func NewTCPTracker(timeout time.Duration, stats *Stats) *TCPTracker {
|
||||
tracker := &TCPTracker{
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
timeout: timeout,
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
stats: stats,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
@ -89,15 +93,13 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker {
|
||||
}
|
||||
|
||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
t.mutex.Lock()
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
// Use preallocated IPs
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
@ -115,18 +117,30 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackNewConnection(6, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
|
||||
}
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
// Lock individual connection for state update
|
||||
conn.Lock()
|
||||
oldState := conn.State
|
||||
t.updateState(conn, flags, true)
|
||||
if oldState != conn.State && t.stats != nil {
|
||||
t.stats.TrackTCPState(conn.State)
|
||||
}
|
||||
conn.Unlock()
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackPacket(6, false, uint64(len(packetData)), false, key)
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, packetData []byte) bool {
|
||||
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
@ -156,6 +170,11 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
t.connections[key] = conn
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@ -169,6 +188,10 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackPacket(6, false, uint64(len(packetData)), true, key)
|
||||
}
|
||||
|
||||
// Handle RST packets
|
||||
if flags&TCPRst != 0 {
|
||||
conn.Lock()
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, nil)
|
||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||
})
|
||||
}
|
||||
@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Send initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
|
||||
|
||||
// Receive SYN-ACK
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
// Send ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
|
||||
|
||||
// Test data transfer
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
|
||||
require.True(t, valid, "Data should be allowed after handshake")
|
||||
},
|
||||
},
|
||||
@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
|
||||
require.True(t, valid, "ACK for FIN should be allowed")
|
||||
|
||||
// Receive FIN from other side
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
|
||||
require.True(t, valid, "FIN should be allowed")
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -122,11 +122,11 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Verify connection is closed
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, nil)
|
||||
t.Helper()
|
||||
|
||||
require.False(t, valid, "Data should be blocked after RST")
|
||||
@ -141,13 +141,13 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Both sides send FIN+ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, nil)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, nil)
|
||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||
|
||||
// Both sides send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, nil)
|
||||
require.True(t, valid, "Final ACKs should be allowed")
|
||||
},
|
||||
},
|
||||
@ -157,7 +157,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@ -165,7 +165,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
@ -184,12 +184,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST in established",
|
||||
setupState: func() {
|
||||
// Establish connection first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
|
||||
},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
|
||||
},
|
||||
wantValid: true,
|
||||
desc: "Should accept RST for established connection",
|
||||
@ -198,7 +198,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST without connection",
|
||||
setupState: func() {},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, nil)
|
||||
},
|
||||
wantValid: false,
|
||||
desc: "Should reject RST without connection",
|
||||
@ -226,17 +226,17 @@ func TestRSTHandling(t *testing.T) {
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, nil)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, nil)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, nil)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -244,12 +244,12 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -257,17 +257,17 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, nil)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -277,9 +277,9 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, nil)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, nil)
|
||||
}
|
||||
i++
|
||||
}
|
||||
@ -290,14 +290,14 @@ func BenchmarkTCPTracker(b *testing.B) {
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
|
||||
tracker := NewTCPTracker(100*time.Millisecond, nil) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, nil)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -26,10 +28,11 @@ type UDPTracker struct {
|
||||
mutex sync.RWMutex
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
stats *Stats
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, stats *Stats) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
@ -40,6 +43,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
stats: stats,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine()
|
||||
@ -47,7 +51,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker {
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
@ -70,14 +74,21 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d
|
||||
conn.lastSeen.Store(now)
|
||||
conn.established.Store(true)
|
||||
t.connections[key] = conn
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackNewConnection(17, srcIP, dstIP, srcPort, dstPort, fw.DirectionOutbound)
|
||||
}
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackPacket(17, false, uint64(len(packetData)), false, key)
|
||||
}
|
||||
conn.lastSeen.Store(now)
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, packetData []byte) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
@ -92,6 +103,10 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
if t.stats != nil {
|
||||
t.stats.TrackPacket(17, false, uint64(len(packetData)), true, key)
|
||||
}
|
||||
|
||||
return conn.IsEstablished() &&
|
||||
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
|
@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout)
|
||||
tracker := NewUDPTracker(tt.timeout, nil)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@ -48,7 +48,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
@ -63,7 +63,7 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1 * time.Second)
|
||||
tracker := NewUDPTracker(1*time.Second, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
@ -72,7 +72,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Track outbound connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -144,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
if tt.sleep > 0 {
|
||||
time.Sleep(tt.sleep)
|
||||
}
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, nil)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@ -189,7 +189,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, nil)
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -219,12 +219,12 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, nil)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, nil)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
@ -232,12 +232,12 @@ func BenchmarkUDPTracker(b *testing.B) {
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, nil)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -22,7 +22,10 @@ import (
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
const (
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
EnvEnableStats = "NB_ENABLE_CONNTRACK_STATS"
|
||||
)
|
||||
|
||||
var (
|
||||
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
|
||||
@ -52,6 +55,9 @@ type Manager struct {
|
||||
udpTracker *conntrack.UDPTracker
|
||||
icmpTracker *conntrack.ICMPTracker
|
||||
tcpTracker *conntrack.TCPTracker
|
||||
|
||||
statsEnabled bool
|
||||
stats *conntrack.Stats
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@ -84,6 +90,7 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
|
||||
|
||||
func create(iface IFaceMapper) (*Manager, error) {
|
||||
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
||||
enableStats, _ := strconv.ParseBool(os.Getenv(EnvEnableStats))
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@ -103,15 +110,21 @@ func create(iface IFaceMapper) (*Manager, error) {
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
wgIface: iface,
|
||||
stateful: !disableConntrack,
|
||||
statsEnabled: enableStats,
|
||||
}
|
||||
|
||||
if enableStats {
|
||||
m.stats = conntrack.NewStats()
|
||||
log.Info("connection tracking statistics enabled")
|
||||
}
|
||||
|
||||
// Only initialize trackers if stateful mode is enabled
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.stats)
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.stats)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.stats)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@ -304,7 +317,10 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
if d.decoded[1] == layers.LayerTypeUDP {
|
||||
// Track UDP state only if enabled
|
||||
if m.stateful {
|
||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
packetData)
|
||||
}
|
||||
return m.checkUDPHooks(d, dstIP, packetData)
|
||||
}
|
||||
@ -313,9 +329,16 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||
if m.stateful {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
getTCPFlags(&d.tcp),
|
||||
packetData)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
packetData)
|
||||
}
|
||||
}
|
||||
|
||||
@ -333,17 +356,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
flags,
|
||||
)
|
||||
}
|
||||
|
||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
var flags uint8
|
||||
if tcp.SYN {
|
||||
@ -367,15 +379,6 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
m.udpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
||||
@ -389,17 +392,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
||||
m.icmpTracker.TrackOutbound(
|
||||
srcIP,
|
||||
dstIP,
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// dropFilter implements filtering logic for incoming packets
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
m.mutex.RLock()
|
||||
@ -423,7 +415,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
||||
}
|
||||
|
||||
// Check connection state only if enabled
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, packetData) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -447,7 +439,7 @@ func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
|
||||
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return m.tcpTracker.IsValidInbound(
|
||||
@ -456,6 +448,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
uint16(d.tcp.SrcPort),
|
||||
uint16(d.tcp.DstPort),
|
||||
getTCPFlags(&d.tcp),
|
||||
packetData,
|
||||
)
|
||||
|
||||
case layers.LayerTypeUDP:
|
||||
@ -464,6 +457,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
dstIP,
|
||||
uint16(d.udp.SrcPort),
|
||||
uint16(d.udp.DstPort),
|
||||
packetData,
|
||||
)
|
||||
|
||||
case layers.LayerTypeICMPv4:
|
||||
@ -473,6 +467,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
||||
d.icmp4.Id,
|
||||
d.icmp4.Seq,
|
||||
d.icmp4.TypeCode.Type(),
|
||||
packetData,
|
||||
)
|
||||
|
||||
// TODO: ICMPv6
|
||||
@ -612,3 +607,11 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
}
|
||||
|
||||
// CollectStats returns connection tracking statistics
|
||||
func (m *Manager) CollectStats() []*firewall.FlowStats {
|
||||
if m.stats == nil {
|
||||
return nil
|
||||
}
|
||||
return m.stats.GetFlowSnapshot()
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -965,6 +966,114 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFirewallStats(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
stats bool
|
||||
longLived bool
|
||||
conns int
|
||||
}{
|
||||
{"nostats_short_100", false, false, 100},
|
||||
{"stats_short_100", true, false, 100},
|
||||
{"nostats_long_100", false, true, 100},
|
||||
{"stats_long_100", true, true, 100},
|
||||
{"nostats_short_1000", false, false, 1000},
|
||||
{"stats_short_1000", true, false, 1000},
|
||||
{"nostats_long_1000", false, true, 1000},
|
||||
{"stats_long_1000", true, true, 1000},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
})
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
b.Setenv(EnvEnableStats, strconv.FormatBool(sc.stats))
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Generate test IPs
|
||||
srcIPs := make([]net.IP, sc.conns)
|
||||
dstIPs := make([]net.IP, sc.conns)
|
||||
for i := 0; i < sc.conns; i++ {
|
||||
srcIPs[i] = generateRandomIPs(1)[0]
|
||||
dstIPs[i] = generateRandomIPs(1)[0]
|
||||
}
|
||||
|
||||
// Pre-generate packets
|
||||
inPackets := make([][]byte, sc.conns)
|
||||
outPackets := make([][]byte, sc.conns)
|
||||
for i := 0; i < sc.conns; i++ {
|
||||
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
|
||||
if sc.longLived {
|
||||
// Establish connection
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
synAck := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.dropFilter(synAck, manager.incomingRules)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
connIdx := i % sc.conns
|
||||
|
||||
if !sc.longLived {
|
||||
// New connection each time
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
|
||||
uint16(1024+connIdx), 80, uint16(conntrack.TCPSyn))
|
||||
synAck := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
|
||||
80, uint16(1024+connIdx), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
|
||||
uint16(1024+connIdx), 80, uint16(conntrack.TCPAck))
|
||||
|
||||
manager.processOutgoingHooks(syn)
|
||||
manager.dropFilter(synAck, manager.incomingRules)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
|
||||
|
||||
if !sc.longLived {
|
||||
// Tear down
|
||||
finClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
|
||||
uint16(1024+connIdx), 80, uint16(conntrack.TCPFin|conntrack.TCPAck))
|
||||
ackServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
|
||||
80, uint16(1024+connIdx), uint16(conntrack.TCPAck))
|
||||
finServer := generateTCPPacketWithFlags(b, dstIPs[connIdx], srcIPs[connIdx],
|
||||
80, uint16(1024+connIdx), uint16(conntrack.TCPFin|conntrack.TCPAck))
|
||||
ackClient := generateTCPPacketWithFlags(b, srcIPs[connIdx], dstIPs[connIdx],
|
||||
uint16(1024+connIdx), 80, uint16(conntrack.TCPAck))
|
||||
|
||||
manager.processOutgoingHooks(finClient)
|
||||
manager.dropFilter(ackServer, manager.incomingRules)
|
||||
manager.dropFilter(finServer, manager.incomingRules)
|
||||
manager.processOutgoingHooks(ackClient)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||
b.Helper()
|
||||
|
@ -400,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, nil)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
@ -518,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, nil)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
|
@ -23,7 +23,7 @@ import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
firewallmanager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@ -158,7 +158,7 @@ type Engine struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
firewall manager.Manager
|
||||
firewall firewallmanager.Manager
|
||||
routeManager routemanager.Manager
|
||||
acl acl.Manager
|
||||
dnsForwardMgr *dnsfwd.Manager
|
||||
@ -1576,6 +1576,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
return nm, nil
|
||||
}
|
||||
|
||||
// GetFirewallStats returns the firewall stats
|
||||
func (e *Engine) GetFirewallStats() []*firewallmanager.FlowStats {
|
||||
if e.firewall != nil {
|
||||
return e.firewall.CollectStats()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
if !enabled {
|
||||
|
@ -44,6 +44,7 @@ iptables.txt: Anonymized iptables rules with packet counters, if --system-info f
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
firewall_stats.json: Anonymized firewall statistics of the NetBird client.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
|
||||
|
||||
@ -139,10 +140,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.logFile == "console" {
|
||||
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
|
||||
}
|
||||
|
||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create zip file: %w", err)
|
||||
@ -202,6 +199,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
|
||||
return fmt.Errorf("add network map: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addFirewallStats(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add firewall stats to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addStateFile(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add state file to debug bundle: %v", err)
|
||||
}
|
||||
@ -356,6 +357,43 @@ func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonym
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addFirewallStats(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
if s.connectClient == nil || s.connectClient.Engine() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
stats := s.connectClient.Engine().GetFirewallStats()
|
||||
if stats == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if req.GetAnonymize() {
|
||||
for _, stat := range stats {
|
||||
if stat.SourceIP != nil {
|
||||
if ip, ok := netip.AddrFromSlice(stat.SourceIP); ok {
|
||||
stat.SourceIP = anonymizer.AnonymizeIP(ip).AsSlice()
|
||||
}
|
||||
}
|
||||
if stat.DestIP != nil {
|
||||
if ip, ok := netip.AddrFromSlice(stat.DestIP); ok {
|
||||
stat.DestIP = anonymizer.AnonymizeIP(ip).AsSlice()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(stats, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal firewall stats: %w", err)
|
||||
}
|
||||
|
||||
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "firewall_stats.json"); err != nil {
|
||||
return fmt.Errorf("add firewall stats to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
path := statemanager.GetDefaultStatePath()
|
||||
if path == "" {
|
||||
|
Loading…
Reference in New Issue
Block a user