From fa748a7ec267079aade6886ad82bd2ab3a13a5ef Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 28 Feb 2025 11:08:35 +0100 Subject: [PATCH] Add userspace flow implementation (#3393) --- client/firewall/create.go | 5 +- client/firewall/create_linux.go | 11 +- client/firewall/uspfilter/allow_netbird.go | 6 +- .../uspfilter/allow_netbird_windows.go | 6 +- client/firewall/uspfilter/conntrack/common.go | 103 ++------ .../uspfilter/conntrack/common_test.go | 37 +-- client/firewall/uspfilter/conntrack/icmp.go | 126 +++++---- .../firewall/uspfilter/conntrack/icmp_test.go | 4 +- client/firewall/uspfilter/conntrack/tcp.go | 239 ++++++++++++------ .../firewall/uspfilter/conntrack/tcp_test.go | 14 +- client/firewall/uspfilter/conntrack/udp.go | 109 +++++--- .../firewall/uspfilter/conntrack/udp_test.go | 25 +- .../firewall/uspfilter/forwarder/forwarder.go | 7 +- client/firewall/uspfilter/forwarder/icmp.go | 27 ++ client/firewall/uspfilter/forwarder/tcp.go | 28 +- client/firewall/uspfilter/forwarder/udp.go | 68 ++++- client/firewall/uspfilter/log/log.go | 190 +++++++++----- client/firewall/uspfilter/log/log_test.go | 121 +++++++++ client/firewall/uspfilter/log/ringbuffer.go | 85 ------- client/firewall/uspfilter/uspfilter.go | 90 ++++--- .../uspfilter/uspfilter_bench_test.go | 16 +- .../uspfilter/uspfilter_filter_test.go | 4 +- client/firewall/uspfilter/uspfilter_test.go | 48 ++-- client/internal/acl/manager_test.go | 8 +- client/internal/dns/server_test.go | 5 +- client/internal/engine.go | 10 +- client/internal/netflow/types/types.go | 39 ++- 27 files changed, 862 insertions(+), 569 deletions(-) create mode 100644 client/firewall/uspfilter/log/log_test.go delete mode 100644 client/firewall/uspfilter/log/ringbuffer.go diff --git a/client/firewall/create.go b/client/firewall/create.go index 37ea5ceb3..7b265e1d1 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -10,17 +10,18 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/statemanager" ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface, disableServerRoutes) + fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index be1b37916..aa2f0d4d1 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -15,6 +15,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers @@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm, disableServerRoutes) + return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) } func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { @@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) { } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) } else { - fm, errUsp = uspfilter.Create(iface, disableServerRoutes) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) } if errUsp != nil { diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 03f23f5e6..1982e4e7e 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -22,17 +22,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } if m.forwarder != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 379585978..cacabe1b3 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -31,17 +31,17 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } if m.forwarder != nil { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index f5f502540..915b57549 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -1,20 +1,26 @@ -// common.go package conntrack import ( + "fmt" "net" - "sync" + "net/netip" "sync/atomic" "time" + + "github.com/google/uuid" + + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - SourceIP net.IP - DestIP net.IP + FlowId uuid.UUID + Direction nftypes.Direction + SourceIP netip.Addr + DestIP netip.Addr SourcePort uint16 DestPort uint16 - lastSeen atomic.Int64 // Unix nano for atomic access + lastSeen atomic.Int64 } // these small methods will be inlined by the compiler @@ -35,92 +41,27 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { return time.Since(lastSeen) > timeout } -// IPAddr is a fixed-size IP address to avoid allocations -type IPAddr [16]byte - -// MakeIPAddr creates an IPAddr from net.IP -func MakeIPAddr(ip net.IP) (addr IPAddr) { - // Optimization: check for v4 first as it's more common - if ip4 := ip.To4(); ip4 != nil { - copy(addr[12:], ip4) - } else { - copy(addr[:], ip.To16()) - } - return addr -} - // ConnKey uniquely identifies a connection type ConnKey struct { - SrcIP IPAddr - DstIP IPAddr + SrcIP netip.Addr + DstIP netip.Addr SrcPort uint16 DstPort uint16 } +func (c ConnKey) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) +} + // makeConnKey creates a connection key func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + srcAddr, _ := netip.AddrFromSlice(srcIP) + dstAddr, _ := netip.AddrFromSlice(dstIP) + return ConnKey{ - SrcIP: MakeIPAddr(srcIP), - DstIP: MakeIPAddr(dstIP), + SrcIP: srcAddr, + DstIP: dstAddr, SrcPort: srcPort, DstPort: dstPort, } } - -// ValidateIPs checks if IPs match without allocation -func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { - if ip4 := pktIP.To4(); ip4 != nil { - // Compare IPv4 addresses (last 4 bytes) - for i := 0; i < 4; i++ { - if connIP[12+i] != ip4[i] { - return false - } - } - return true - } - // Compare full IPv6 addresses - ip6 := pktIP.To16() - for i := 0; i < 16; i++ { - if connIP[i] != ip6[i] { - return false - } - } - return true -} - -// PreallocatedIPs is a pool of IP byte slices to reduce allocations -type PreallocatedIPs struct { - sync.Pool -} - -// NewPreallocatedIPs creates a new IP pool -func NewPreallocatedIPs() *PreallocatedIPs { - return &PreallocatedIPs{ - Pool: sync.Pool{ - New: func() interface{} { - ip := make(net.IP, 16) - return &ip - }, - }, - } -} - -// Get retrieves an IP from the pool -func (p *PreallocatedIPs) Get() net.IP { - return *p.Pool.Get().(*net.IP) -} - -// Put returns an IP to the pool -func (p *PreallocatedIPs) Put(ip net.IP) { - p.Pool.Put(&ip) -} - -// copyIP copies an IP address efficiently -func copyIP(dst, src net.IP) { - if len(src) == 16 { - copy(dst, src) - } else { - // Handle IPv4 - copy(dst[12:], src.To4()) - } -} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 81fa64b19..84f1b1b75 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -1,50 +1,23 @@ package conntrack import ( + "context" "net" "testing" "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + "github.com/netbirdio/netbird/client/internal/netflow" ) var logger = log.NewFromLogrus(logrus.StandardLogger()) - -func BenchmarkIPOperations(b *testing.B) { - b.Run("MakeIPAddr", func(b *testing.B) { - ip := net.ParseIP("192.168.1.1") - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = MakeIPAddr(ip) - } - }) - - b.Run("ValidateIPs", func(b *testing.B) { - ip1 := net.ParseIP("192.168.1.1") - ip2 := net.ParseIP("192.168.1.1") - addr := MakeIPAddr(ip1) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = ValidateIPs(addr, ip2) - } - }) - - b.Run("IPPool", func(b *testing.B) { - pool := NewPreallocatedIPs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - ip := pool.Get() - pool.Put(ip) - } - }) - -} +var flowLogger = netflow.NewManager(context.Background()).GetLogger() // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger) + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() // Generate different IPs @@ -69,7 +42,7 @@ func BenchmarkMemoryPressure(b *testing.B) { }) b.Run("UDPHighLoad", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout, logger) + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) defer tracker.Close() // Generate different IPs diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 25cd9e87d..630f6d04e 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -1,13 +1,17 @@ package conntrack import ( + "fmt" "net" + "net/netip" "sync" "time" "github.com/google/gopacket/layers" + "github.com/google/uuid" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) const ( @@ -19,18 +23,19 @@ const ( // ICMPConnKey uniquely identifies an ICMP connection type ICMPConnKey struct { - // Supports both IPv4 and IPv6 - SrcIP [16]byte - DstIP [16]byte - Sequence uint16 // ICMP sequence number - ID uint16 // ICMP identifier + SrcIP netip.Addr + DstIP netip.Addr + Sequence uint16 + ID uint16 +} + +func (i *ICMPConnKey) String() string { + return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.Sequence, i.ID) } // ICMPConnTrack represents an ICMP connection state type ICMPConnTrack struct { BaseConnTrack - Sequence uint16 - ID uint16 } // ICMPTracker manages ICMP connection states @@ -41,11 +46,11 @@ type ICMPTracker struct { cleanupTicker *time.Ticker mutex sync.RWMutex done chan struct{} - ipPool *PreallocatedIPs + flowLogger nftypes.FlowLogger } // NewICMPTracker creates a new ICMP connection tracker -func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { +func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { if timeout == 0 { timeout = DefaultICMPTimeout } @@ -56,41 +61,65 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), done: make(chan struct{}), - ipPool: NewPreallocatedIPs(), + flowLogger: flowLogger, } go tracker.cleanupRoutine() return tracker } -// TrackOutbound records an outbound ICMP Echo Request -func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { +func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) { key := makeICMPKey(srcIP, dstIP, id, seq) - t.mutex.Lock() + t.mutex.RLock() conn, exists := t.connections[key] - if !exists { - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, srcIP) - copyIP(dstIPCopy, dstIP) + t.mutex.RUnlock() - conn = &ICMPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - }, - ID: id, - Sequence: seq, - } + if exists { conn.UpdateLastSeen() - t.connections[key] = conn - t.logger.Trace("New ICMP connection %v", key) + return key, true } + + return key, false +} + +// TrackOutbound records an outbound ICMP connection +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, id, seq, nftypes.Egress) + } +} + +// TrackInbound records an inbound ICMP Echo Request +func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + t.track(srcIP, dstIP, id, seq, nftypes.Ingress) +} + +// track is the common implementation for tracking both inbound and outbound ICMP connections +func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, direction nftypes.Direction) { + key, exists := t.updateIfExists(srcIP, dstIP, id, seq) + if exists { + return + } + + conn := &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + FlowId: uuid.New(), + Direction: direction, + SourceIP: key.SrcIP, + DestIP: key.DstIP, + }, + } + conn.UpdateLastSeen() + + t.mutex.Lock() + t.connections[key] = conn t.mutex.Unlock() - conn.UpdateLastSeen() + t.logger.Trace("New %s ICMP connection %s", conn.Direction, key) + t.sendEvent(nftypes.TypeStart, key, conn) } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request @@ -105,18 +134,13 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq conn, exists := t.connections[key] t.mutex.RUnlock() - if !exists { + if !exists || conn.timeoutExceeded(t.timeout) { return false } - if conn.timeoutExceeded(t.timeout) { - return false - } + conn.UpdateLastSeen() - return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && - ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && - conn.ID == id && - conn.Sequence == seq + return true } func (t *ICMPTracker) cleanupRoutine() { @@ -129,17 +153,17 @@ func (t *ICMPTracker) cleanupRoutine() { } } } + func (t *ICMPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() for key, conn := range t.connections { if conn.timeoutExceeded(t.timeout) { - t.ipPool.Put(conn.SourceIP) - t.ipPool.Put(conn.DestIP) delete(t.connections, key) - t.logger.Debug("Removed ICMP connection %v (timeout)", key) + t.logger.Debug("Removed ICMP connection %s (timeout)", &key) + t.sendEvent(nftypes.TypeEnd, key, conn) } } } @@ -150,19 +174,29 @@ func (t *ICMPTracker) Close() { close(t.done) t.mutex.Lock() - for _, conn := range t.connections { - t.ipPool.Put(conn.SourceIP) - t.ipPool.Put(conn.DestIP) - } t.connections = nil t.mutex.Unlock() } +func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) { + t.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: conn.FlowId, + Type: typ, + Direction: conn.Direction, + Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 + SourceIP: key.SrcIP, + DestIP: key.DstIP, + // TODO: add icmp code/type, + }) +} + // makeICMPKey creates an ICMP connection key func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + srcAddr, _ := netip.AddrFromSlice(srcIP) + dstAddr, _ := netip.AddrFromSlice(dstIP) return ICMPConnKey{ - SrcIP: MakeIPAddr(srcIP), - DstIP: MakeIPAddr(dstIP), + SrcIP: srcAddr, + DstIP: dstAddr, ID: id, Sequence: seq, } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 32553c836..ef5317d41 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -7,7 +7,7 @@ import ( func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout, logger) + tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout, logger) + tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 7c12e8ad0..af5ecb302 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -8,7 +8,10 @@ import ( "sync/atomic" "time" + "github.com/google/uuid" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) const ( @@ -39,6 +42,35 @@ const ( // TCPState represents the state of a TCP connection type TCPState int +func (s TCPState) String() string { + switch s { + case TCPStateNew: + return "New" + case TCPStateSynSent: + return "SYN Sent" + case TCPStateSynReceived: + return "SYN Received" + case TCPStateEstablished: + return "Established" + case TCPStateFinWait1: + return "FIN Wait 1" + case TCPStateFinWait2: + return "FIN Wait 2" + case TCPStateClosing: + return "Closing" + case TCPStateTimeWait: + return "Time Wait" + case TCPStateCloseWait: + return "Close Wait" + case TCPStateLastAck: + return "Last ACK" + case TCPStateClosed: + return "Closed" + default: + return "Unknown" + } +} + const ( TCPStateNew TCPState = iota TCPStateSynSent @@ -53,19 +85,12 @@ const ( TCPStateClosed ) -// TCPConnKey uniquely identifies a TCP connection -type TCPConnKey struct { - SrcIP [16]byte - DstIP [16]byte - SrcPort uint16 - DstPort uint16 -} - // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { BaseConnTrack State TCPState established atomic.Bool + tombstone atomic.Bool sync.RWMutex } @@ -79,6 +104,16 @@ func (t *TCPConnTrack) SetEstablished(state bool) { t.established.Store(state) } +// IsTombstone safely checks if the connection is marked for deletion +func (t *TCPConnTrack) IsTombstone() bool { + return t.tombstone.Load() +} + +// SetTombstone safely marks the connection for deletion +func (t *TCPConnTrack) SetTombstone() { + t.tombstone.Store(true) +} + // TCPTracker manages TCP connection states type TCPTracker struct { logger *nblog.Logger @@ -87,68 +122,94 @@ type TCPTracker struct { cleanupTicker *time.Ticker done chan struct{} timeout time.Duration - ipPool *PreallocatedIPs + flowLogger nftypes.FlowLogger } // NewTCPTracker creates a new TCP connection tracker -func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { +func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker { + if timeout == 0 { + timeout = DefaultTCPTimeout + } + tracker := &TCPTracker{ logger: logger, connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), timeout: timeout, - ipPool: NewPreallocatedIPs(), + flowLogger: flowLogger, } go tracker.cleanupRoutine() return tracker } -// TrackOutbound processes an outbound TCP packet and updates connection state -func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { - // Create key before lock +func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - t.mutex.Lock() + t.mutex.RLock() conn, exists := t.connections[key] - if !exists { - // Use preallocated IPs - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, srcIP) - copyIP(dstIPCopy, dstIP) + t.mutex.RUnlock() - conn = &TCPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - SourcePort: srcPort, - DestPort: dstPort, - }, - State: TCPStateNew, - } + if exists { + conn.Lock() + t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) conn.UpdateLastSeen() - conn.established.Store(false) - t.connections[key] = conn + conn.Unlock() - t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) + return key, true } + + return key, false +} + +// TrackOutbound records an outbound TCP connection +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists { + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress) + } +} + +// TrackInbound processes an inbound TCP packet and updates connection state +func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress) +} + +// track is the common implementation for tracking both inbound and outbound connections +func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { + key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags) + if exists { + return + } + + conn := &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + FlowId: uuid.New(), + Direction: direction, + SourceIP: key.SrcIP, + DestIP: key.DstIP, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + + conn.UpdateLastSeen() + conn.established.Store(false) + conn.tombstone.Store(false) + + t.logger.Trace("New %s TCP connection: %s", direction, key) + t.updateState(key, conn, flags, direction == nftypes.Egress) + + t.mutex.Lock() + t.connections[key] = conn t.mutex.Unlock() - // Lock individual connection for state update - conn.Lock() - t.updateState(conn, flags, true) - conn.Unlock() - conn.UpdateLastSeen() + t.sendEvent(nftypes.TypeStart, key, conn) } // IsValidInbound checks if an inbound TCP packet matches a tracked connection func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { - if !isValidFlagCombination(flags) { - return false - } - key := makeConnKey(dstIP, srcIP, dstPort, srcPort) t.mutex.RLock() @@ -159,21 +220,25 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - // Handle RST packets + // Handle RST flag specially - it always causes transition to closed if flags&TCPRst != 0 { - conn.Lock() - if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { - conn.State = TCPStateClosed - conn.SetEstablished(false) - conn.Unlock() + if conn.IsTombstone() { return true } + + conn.Lock() + conn.SetTombstone() + conn.State = TCPStateClosed + conn.SetEstablished(false) conn.Unlock() - return false + + t.logger.Trace("TCP connection reset: %s", key) + t.sendEvent(nftypes.TypeEnd, key, conn) + return true } conn.Lock() - t.updateState(conn, flags, false) + t.updateState(key, conn, flags, false) conn.UpdateLastSeen() isEstablished := conn.IsEstablished() isValidState := t.isValidStateForFlags(conn.State, flags) @@ -183,18 +248,15 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, } // updateState updates the TCP connection state based on flags -func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { - // Handle RST flag specially - it always causes transition to closed - if flags&TCPRst != 0 { - conn.State = TCPStateClosed - conn.SetEstablished(false) +func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { + state := conn.State + defer func() { + if state != conn.State { + t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State) + } + }() - t.logger.Trace("TCP connection reset: %s:%d -> %s:%d", - conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) - return - } - - switch conn.State { + switch state { case TCPStateNew: if flags&TCPSyn != 0 && flags&TCPAck == 0 { conn.State = TCPStateSynSent @@ -241,6 +303,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo case TCPStateFinWait2: if flags&TCPFin != 0 { conn.State = TCPStateTimeWait + + t.logger.Trace("TCP connection %s completed", key) + t.sendEvent(nftypes.TypeEnd, key, conn) } case TCPStateClosing: @@ -248,8 +313,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo conn.State = TCPStateTimeWait // Keep established = false from previous state - t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d", - conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) + t.logger.Trace("TCP connection %s closed (simultaneous)", key) + t.sendEvent(nftypes.TypeEnd, key, conn) } case TCPStateCloseWait: @@ -260,17 +325,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo case TCPStateLastAck: if flags&TCPAck != 0 { conn.State = TCPStateClosed + conn.SetTombstone() - t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d", - conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) + // Send close event for gracefully closed connections + t.sendEvent(nftypes.TypeEnd, key, conn) + t.logger.Trace("TCP connection %s closed gracefully", key) } - - case TCPStateTimeWait: - // Stay in TIME-WAIT for 2MSL before transitioning to closed - // This is handled by the cleanup routine - - t.logger.Trace("TCP connection completed - %s:%d -> %s:%d", - conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } } @@ -331,6 +391,12 @@ func (t *TCPTracker) cleanup() { defer t.mutex.Unlock() for key, conn := range t.connections { + if conn.IsTombstone() { + // Clean up tombstoned connections without sending an event + delete(t.connections, key) + continue + } + var timeout time.Duration switch { case conn.State == TCPStateTimeWait: @@ -341,14 +407,16 @@ func (t *TCPTracker) cleanup() { timeout = TCPHandshakeTimeout } - lastSeen := conn.GetLastSeen() - if time.Since(lastSeen) > timeout { + if conn.timeoutExceeded(timeout) { // Return IPs to pool - t.ipPool.Put(conn.SourceIP) - t.ipPool.Put(conn.DestIP) delete(t.connections, key) - t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) + t.logger.Trace("Cleaned up timed-out TCP connection %s", &key) + + // event already handled by state change + if conn.State != TCPStateTimeWait { + t.sendEvent(nftypes.TypeEnd, key, conn) + } } } } @@ -360,10 +428,6 @@ func (t *TCPTracker) Close() { // Clean up all remaining IPs t.mutex.Lock() - for _, conn := range t.connections { - t.ipPool.Put(conn.SourceIP) - t.ipPool.Put(conn.DestIP) - } t.connections = nil t.mutex.Unlock() } @@ -381,3 +445,16 @@ func isValidFlagCombination(flags uint8) bool { return true } + +func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) { + t.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: conn.FlowId, + Type: typ, + Direction: conn.Direction, + Protocol: nftypes.TCP, + SourceIP: key.SrcIP, + DestIP: key.DstIP, + SourcePort: key.SrcPort, + DestPort: key.DstPort, + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 5f4c43915..200d77501 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCPStateMachine(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger) + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Helper() - tracker = NewTCPTracker(DefaultTCPTimeout, logger) + tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tt.test(t) }) } @@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) { } func TestRSTHandling(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger) + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger) + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger) + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger) + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) { // Benchmark connection cleanup func BenchmarkCleanup(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing + tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing defer tracker.Close() // Pre-populate with expired connections diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e73465e31..922db371d 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -5,7 +5,10 @@ import ( "sync" "time" + "github.com/google/uuid" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) const ( @@ -28,11 +31,11 @@ type UDPTracker struct { cleanupTicker *time.Ticker mutex sync.RWMutex done chan struct{} - ipPool *PreallocatedIPs + flowLogger nftypes.FlowLogger } // NewUDPTracker creates a new UDP connection tracker -func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { +func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker { if timeout == 0 { timeout = DefaultUDPTimeout } @@ -43,7 +46,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), done: make(chan struct{}), - ipPool: NewPreallocatedIPs(), + flowLogger: flowLogger, } go tracker.cleanupRoutine() @@ -52,32 +55,57 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists { + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress) + } +} + +// TrackInbound records an inbound UDP connection +func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress) +} + +func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - t.mutex.Lock() + t.mutex.RLock() conn, exists := t.connections[key] - if !exists { - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, srcIP) - copyIP(dstIPCopy, dstIP) + t.mutex.RUnlock() - conn = &UDPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - SourcePort: srcPort, - DestPort: dstPort, - }, - } + if exists { conn.UpdateLastSeen() - t.connections[key] = conn - - t.logger.Trace("New UDP connection: %v", conn) + return key, true } + + return key, false +} + +// track is the common implementation for tracking both inbound and outbound connections +func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) { + key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort) + if exists { + return + } + + conn := &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + FlowId: uuid.New(), + Direction: direction, + SourceIP: key.SrcIP, + DestIP: key.DstIP, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.UpdateLastSeen() + + t.mutex.Lock() + t.connections[key] = conn t.mutex.Unlock() - conn.UpdateLastSeen() + t.logger.Trace("New %s UDP connection: %s", direction, key) + t.sendEvent(nftypes.TypeStart, key, conn) } // IsValidInbound checks if an inbound packet matches a tracked connection @@ -88,18 +116,13 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, conn, exists := t.connections[key] t.mutex.RUnlock() - if !exists { + if !exists || conn.timeoutExceeded(t.timeout) { return false } - if conn.timeoutExceeded(t.timeout) { - return false - } + conn.UpdateLastSeen() - return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && - ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && - conn.DestPort == srcPort && - conn.SourcePort == dstPort + return true } // cleanupRoutine periodically removes stale connections @@ -120,11 +143,10 @@ func (t *UDPTracker) cleanup() { for key, conn := range t.connections { if conn.timeoutExceeded(t.timeout) { - t.ipPool.Put(conn.SourceIP) - t.ipPool.Put(conn.DestIP) delete(t.connections, key) - t.logger.Trace("Removed UDP connection %v (timeout)", conn) + t.logger.Trace("Removed UDP connection %s (timeout)", key) + t.sendEvent(nftypes.TypeEnd, key, conn) } } } @@ -135,10 +157,6 @@ func (t *UDPTracker) Close() { close(t.done) t.mutex.Lock() - for _, conn := range t.connections { - t.ipPool.Put(conn.SourceIP) - t.ipPool.Put(conn.DestIP) - } t.connections = nil t.mutex.Unlock() } @@ -150,14 +168,23 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d key := makeConnKey(srcIP, dstIP, srcPort, dstPort) conn, exists := t.connections[key] - if !exists { - return nil, false - } - - return conn, true + return conn, exists } // Timeout returns the configured timeout duration for the tracker func (t *UDPTracker) Timeout() time.Duration { return t.timeout } + +func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) { + t.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: conn.FlowId, + Type: typ, + Direction: conn.Direction, + Protocol: nftypes.UDP, + SourceIP: key.SrcIP, + DestIP: key.DstIP, + SourcePort: key.SrcPort, + DestPort: key.DstPort, + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index fa83ee356..29c7111fd 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -2,6 +2,7 @@ package conntrack import ( "net" + "net/netip" "testing" "time" @@ -29,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := NewUDPTracker(tt.timeout, logger) + tracker := NewUDPTracker(tt.timeout, logger, flowLogger) assert.NotNil(t, tracker) assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) @@ -40,29 +41,29 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout, logger) + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.2") - dstIP := net.ParseIP("192.168.1.3") + srcIP := netip.MustParseAddr("192.168.1.2") + dstIP := netip.MustParseAddr("192.168.1.3") srcPort := uint16(12345) dstPort := uint16(53) - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) // Verify connection was tracked - key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) conn, exists := tracker.connections[key] require.True(t, exists) - assert.True(t, conn.SourceIP.Equal(srcIP)) - assert.True(t, conn.DestIP.Equal(dstIP)) + assert.True(t, conn.SourceIP.Compare(srcIP) == 0) + assert.True(t, conn.DestIP.Compare(dstIP) == 0) assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, dstPort, conn.DestPort) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) } func TestUDPTracker_IsValidInbound(t *testing.T) { - tracker := NewUDPTracker(1*time.Second, logger) + tracker := NewUDPTracker(1*time.Second, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -160,8 +161,8 @@ func TestUDPTracker_Cleanup(t *testing.T) { timeout: timeout, cleanupTicker: time.NewTicker(cleanupInterval), done: make(chan struct{}), - ipPool: NewPreallocatedIPs(), logger: logger, + flowLogger: flowLogger, } // Start cleanup routine @@ -211,7 +212,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { func BenchmarkUDPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout, logger) + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -224,7 +225,7 @@ func BenchmarkUDPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout, logger) + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 4ed152b79..0dff3acc7 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/firewall/uspfilter/common" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) const ( @@ -29,6 +30,7 @@ const ( type Forwarder struct { logger *nblog.Logger + flowLogger nftypes.FlowLogger stack *stack.Stack endpoint *endpoint udpForwarder *udpForwarder @@ -38,7 +40,7 @@ type Forwarder struct { netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { +func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar ctx, cancel := context.WithCancel(context.Background()) f := &Forwarder{ logger: logger, + flowLogger: flowLogger, stack: s, endpoint: endpoint, - udpForwarder: newUDPForwarder(mtu, logger), + udpForwarder: newUDPForwarder(mtu, logger, flowLogger), ctx: ctx, cancel: cancel, netstack: netstack, diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index 14cdc37be..e842ef0de 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -3,14 +3,21 @@ package forwarder import ( "context" "net" + "net/netip" "time" + "github.com/google/uuid" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) // handleICMP handles ICMP packets from the network stack func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { + flowID := uuid.New() + f.sendICMPEvent(nftypes.TypeStart, flowID, id) + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) defer cancel() @@ -20,6 +27,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf if err != nil { f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) + f.sendICMPEvent(nftypes.TypeEnd, flowID, id) + // This will make netstack reply on behalf of the original destination, that's ok for now return false } @@ -27,6 +36,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf if err := conn.Close(); err != nil { f.logger.Debug("Failed to close ICMP socket: %v", err) } + + f.sendICMPEvent(nftypes.TypeEnd, flowID, id) }() dstIP := f.determineDialAddr(id.LocalAddress) @@ -101,9 +112,25 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds if err := f.InjectIncomingPacket(fullPacket); err != nil { f.logger.Error("Failed to inject ICMP response: %v", err) + return true } f.logger.Trace("Forwarded ICMP echo reply for %v", id) return true } + +// sendICMPEvent stores flow events for ICMP packets +func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { + f.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: 1, + // TODO: handle ipv6 + SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), + DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), + SourcePort: id.LocalPort, + DestPort: id.RemotePort, + }) +} diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index 6d7cf3b6a..e48d06a69 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -5,18 +5,25 @@ import ( "fmt" "io" "net" + "net/netip" + "github.com/google/uuid" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) // handleTCP is called by the TCP forwarder for new connections. func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { id := r.ID() + flowID := uuid.New() + f.sendTCPEvent(nftypes.TypeStart, flowID, id) + dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) @@ -46,10 +53,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { f.logger.Trace("forwarder: established TCP connection %v", id) - go f.proxyTCP(id, inConn, outConn, ep) + go f.proxyTCP(id, inConn, outConn, ep, flowID) } -func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { +func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { defer func() { if err := inConn.Close(); err != nil { f.logger.Debug("forwarder: inConn close error: %v", err) @@ -58,6 +65,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn f.logger.Debug("forwarder: outConn close error: %v", err) } ep.Close() + + f.sendTCPEvent(nftypes.TypeEnd, flowID, id) }() // Create context for managing the proxy goroutines @@ -88,3 +97,18 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn return } } + +func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { + + f.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: 6, + // TODO: handle ipv6 + SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), + DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), + SourcePort: id.LocalPort, + DestPort: id.RemotePort, + }) +} diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index c37740587..e3a31e26c 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "sync/atomic" "time" + "github.com/google/uuid" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -16,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) const ( @@ -28,15 +31,17 @@ type udpPacketConn struct { lastSeen atomic.Int64 cancel context.CancelFunc ep tcpip.Endpoint + flowID uuid.UUID } type udpForwarder struct { sync.RWMutex - logger *nblog.Logger - conns map[stack.TransportEndpointID]*udpPacketConn - bufPool sync.Pool - ctx context.Context - cancel context.CancelFunc + logger *nblog.Logger + flowLogger nftypes.FlowLogger + conns map[stack.TransportEndpointID]*udpPacketConn + bufPool sync.Pool + ctx context.Context + cancel context.CancelFunc } type idleConn struct { @@ -44,13 +49,14 @@ type idleConn struct { conn *udpPacketConn } -func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { +func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ - logger: logger, - conns: make(map[stack.TransportEndpointID]*udpPacketConn), - ctx: ctx, - cancel: cancel, + logger: logger, + flowLogger: flowLogger, + conns: make(map[stack.TransportEndpointID]*udpPacketConn), + ctx: ctx, + cancel: cancel, bufPool: sync.Pool{ New: func() any { b := make([]byte, mtu) @@ -83,6 +89,21 @@ func (f *udpForwarder) Stop() { } } +// sendUDPEvent stores flow events for UDP connections +func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { + f.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: 17, + // TODO: handle ipv6 + SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), + DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), + SourcePort: id.LocalPort, + DestPort: id.RemotePort, + }) +} + // cleanup periodically removes idle UDP connections func (f *udpForwarder) cleanup() { ticker := time.NewTicker(time.Minute) @@ -119,6 +140,8 @@ func (f *udpForwarder) cleanup() { f.Unlock() f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) + + f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id) } } } @@ -141,10 +164,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { return } + flowID := uuid.New() + f.sendUDPEvent(nftypes.TypeStart, flowID, id) + dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) if err != nil { f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) + f.sendUDPEvent(nftypes.TypeEnd, flowID, id) // TODO: Send ICMP error message return } @@ -157,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) } + f.sendUDPEvent(nftypes.TypeEnd, flowID, id) return } @@ -168,6 +196,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { outConn: outConn, cancel: connCancel, ep: ep, + flowID: flowID, } pConn.updateLastSeen() @@ -182,6 +211,8 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { if err := outConn.Close(); err != nil { f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) } + + f.sendUDPEvent(nftypes.TypeEnd, flowID, id) return } f.udpForwarder.conns[id] = pConn @@ -206,6 +237,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack f.udpForwarder.Lock() delete(f.udpForwarder.conns, id) f.udpForwarder.Unlock() + + f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id) }() errChan := make(chan error, 2) @@ -231,6 +264,21 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack } } +// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version +func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { + f.flowLogger.StoreEvent(nftypes.EventFields{ + FlowID: flowID, + Type: typ, + Direction: nftypes.Ingress, + Protocol: 17, // UDP protocol number + // TODO: handle ipv6 + SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), + DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), + SourcePort: id.LocalPort, + DestPort: id.RemotePort, + }) +} + func (c *udpPacketConn) updateLastSeen() { c.lastSeen.Store(time.Now().UnixNano()) } diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 984b6ad08..c55df6e90 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -1,4 +1,4 @@ -// Package logger provides a high-performance, non-blocking logger for userspace networking +// Package log provides a high-performance, non-blocking logger for userspace networking package log import ( @@ -13,13 +13,12 @@ import ( ) const ( - maxBatchSize = 1024 * 16 // 16KB max batch size - maxMessageSize = 1024 * 2 // 2KB per message - bufferSize = 1024 * 256 // 256KB ring buffer + maxBatchSize = 1024 * 16 + maxMessageSize = 1024 * 2 defaultFlushInterval = 2 * time.Second + logChannelSize = 1000 ) -// Level represents log severity type Level uint32 const ( @@ -42,32 +41,37 @@ var levelStrings = map[Level]string{ LevelTrace: "TRAC", } -// Logger is a high-performance, non-blocking logger -type Logger struct { - output io.Writer - level atomic.Uint32 - buffer *ringBuffer - shutdown chan struct{} - closeOnce sync.Once - wg sync.WaitGroup - - // Reusable buffer pool for formatting messages - bufPool sync.Pool +type logMessage struct { + level Level + format string + args []any } +// Logger is a high-performance, non-blocking logger +type Logger struct { + output io.Writer + level atomic.Uint32 + msgChannel chan logMessage + shutdown chan struct{} + closeOnce sync.Once + wg sync.WaitGroup + bufPool sync.Pool +} + +// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger func NewFromLogrus(logrusLogger *log.Logger) *Logger { l := &Logger{ - output: logrusLogger.Out, - buffer: newRingBuffer(bufferSize), - shutdown: make(chan struct{}), + output: logrusLogger.Out, + msgChannel: make(chan logMessage, logChannelSize), + shutdown: make(chan struct{}), bufPool: sync.Pool{ - New: func() interface{} { - // Pre-allocate buffer for message formatting + New: func() any { b := make([]byte, 0, maxMessageSize) return &b }, }, } + logrusLevel := logrusLogger.GetLevel() l.level.Store(uint32(logrusLevel)) level := levelStrings[Level(logrusLevel)] @@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger { return l } +// SetLevel sets the logging level func (l *Logger) SetLevel(level Level) { l.level.Store(uint32(level)) - log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } -func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { - *buf = (*buf)[:0] - - // Timestamp - *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") - *buf = append(*buf, ' ') - - // Level - *buf = append(*buf, levelStrings[level]...) - *buf = append(*buf, ' ') - - // Message - if len(args) > 0 { - *buf = append(*buf, fmt.Sprintf(format, args...)...) - } else { - *buf = append(*buf, format...) +func (l *Logger) log(level Level, format string, args ...any) { + select { + case l.msgChannel <- logMessage{level: level, format: format, args: args}: + default: } - - *buf = append(*buf, '\n') } -func (l *Logger) log(level Level, format string, args ...interface{}) { - bufp := l.bufPool.Get().(*[]byte) - l.formatMessage(bufp, level, format, args...) - - if len(*bufp) > maxMessageSize { - *bufp = (*bufp)[:maxMessageSize] - } - _, _ = l.buffer.Write(*bufp) - - l.bufPool.Put(bufp) -} - -func (l *Logger) Error(format string, args ...interface{}) { +// Error logs a message at error level +func (l *Logger) Error(format string, args ...any) { if l.level.Load() >= uint32(LevelError) { l.log(LevelError, format, args...) } } -func (l *Logger) Warn(format string, args ...interface{}) { +// Warn logs a message at warning level +func (l *Logger) Warn(format string, args ...any) { if l.level.Load() >= uint32(LevelWarn) { l.log(LevelWarn, format, args...) } } -func (l *Logger) Info(format string, args ...interface{}) { +// Info logs a message at info level +func (l *Logger) Info(format string, args ...any) { if l.level.Load() >= uint32(LevelInfo) { l.log(LevelInfo, format, args...) } } -func (l *Logger) Debug(format string, args ...interface{}) { +// Debug logs a message at debug level +func (l *Logger) Debug(format string, args ...any) { if l.level.Load() >= uint32(LevelDebug) { l.log(LevelDebug, format, args...) } } -func (l *Logger) Trace(format string, args ...interface{}) { +// Trace logs a message at trace level +func (l *Logger) Trace(format string, args ...any) { if l.level.Load() >= uint32(LevelTrace) { l.log(LevelTrace, format, args...) } } -// worker periodically flushes the buffer +func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) { + *buf = (*buf)[:0] + *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000-07:00") + *buf = append(*buf, ' ') + *buf = append(*buf, levelStrings[level]...) + *buf = append(*buf, ' ') + + var msg string + if len(args) > 0 { + msg = fmt.Sprintf(format, args...) + } else { + msg = format + } + *buf = append(*buf, msg...) + *buf = append(*buf, '\n') + + if len(*buf) > maxMessageSize { + *buf = (*buf)[:maxMessageSize] + } +} + +// processMessage handles a single log message and adds it to the buffer +func (l *Logger) processMessage(msg logMessage, buffer *[]byte) { + bufp := l.bufPool.Get().(*[]byte) + defer l.bufPool.Put(bufp) + + l.formatMessage(bufp, msg.level, msg.format, msg.args...) + + if len(*buffer)+len(*bufp) > maxBatchSize { + _, _ = l.output.Write(*buffer) + *buffer = (*buffer)[:0] + } + *buffer = append(*buffer, *bufp...) +} + +// flushBuffer writes the accumulated buffer to output +func (l *Logger) flushBuffer(buffer *[]byte) { + if len(*buffer) > 0 { + _, _ = l.output.Write(*buffer) + *buffer = (*buffer)[:0] + } +} + +// processBatch processes as many messages as possible without blocking +func (l *Logger) processBatch(buffer *[]byte) { + for len(*buffer) < maxBatchSize { + select { + case msg := <-l.msgChannel: + l.processMessage(msg, buffer) + default: + return + } + } +} + +// handleShutdown manages the graceful shutdown sequence with timeout +func (l *Logger) handleShutdown(buffer *[]byte) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + for { + select { + case msg := <-l.msgChannel: + l.processMessage(msg, buffer) + case <-ctx.Done(): + l.flushBuffer(buffer) + return + } + + if len(l.msgChannel) == 0 { + l.flushBuffer(buffer) + return + } + } +} + +// worker is the main goroutine that processes log messages func (l *Logger) worker() { defer l.wg.Done() ticker := time.NewTicker(defaultFlushInterval) defer ticker.Stop() - buf := make([]byte, 0, maxBatchSize) + buffer := make([]byte, 0, maxBatchSize) for { select { case <-l.shutdown: + l.handleShutdown(&buffer) return case <-ticker.C: - // Read accumulated messages - n, _ := l.buffer.Read(buf[:cap(buf)]) - if n == 0 { - continue - } - - // Write batch - _, _ = l.output.Write(buf[:n]) + l.flushBuffer(&buffer) + case msg := <-l.msgChannel: + l.processMessage(msg, &buffer) + l.processBatch(&buffer) } } } diff --git a/client/firewall/uspfilter/log/log_test.go b/client/firewall/uspfilter/log/log_test.go new file mode 100644 index 000000000..e7da9a8e9 --- /dev/null +++ b/client/firewall/uspfilter/log/log_test.go @@ -0,0 +1,121 @@ +package log_test + +import ( + "context" + "testing" + "time" + + "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +type discard struct{} + +func (d *discard) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func BenchmarkLogger(b *testing.B) { + simpleMessage := "Connection established" + + conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d" + srcIP := "192.168.1.1" + srcPort := uint16(12345) + dstIP := "10.0.0.1" + dstPort := uint16(443) + state := 4 // TCPStateEstablished + + complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s" + protocol := "TCP" + direction := "outbound" + flags := uint16(0x18) // ACK + PSH + sequence := uint32(123456789) + acknowledged := uint32(987654321) + payloadSize := 1460 + fragmented := false + connID := "f7a12b3e-c456-7890-d123-456789abcdef" + + b.Run("SimpleMessage", func(b *testing.B) { + logger := createTestLogger() + defer cleanupLogger(logger) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Trace(simpleMessage) + } + }) + + b.Run("ConntrackMessage", func(b *testing.B) { + logger := createTestLogger() + defer cleanupLogger(logger) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state) + } + }) + + b.Run("ComplexMessage", func(b *testing.B) { + logger := createTestLogger() + defer cleanupLogger(logger) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID) + } + }) +} + +// BenchmarkLoggerParallel tests the logger under concurrent load +func BenchmarkLoggerParallel(b *testing.B) { + logger := createTestLogger() + defer cleanupLogger(logger) + + conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d" + srcIP := "192.168.1.1" + srcPort := uint16(12345) + dstIP := "10.0.0.1" + dstPort := uint16(443) + state := 4 + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state) + } + }) +} + +// BenchmarkLoggerBurst tests how the logger handles bursts of messages +func BenchmarkLoggerBurst(b *testing.B) { + logger := createTestLogger() + defer cleanupLogger(logger) + + conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d" + srcIP := "192.168.1.1" + srcPort := uint16(12345) + dstIP := "10.0.0.1" + dstPort := uint16(443) + state := 4 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 100; j++ { + logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state) + } + } +} + +func createTestLogger() *log.Logger { + logrusLogger := logrus.New() + logrusLogger.SetOutput(&discard{}) + logrusLogger.SetLevel(logrus.TraceLevel) + return log.NewFromLogrus(logrusLogger) +} + +func cleanupLogger(logger *log.Logger) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _ = logger.Stop(ctx) +} diff --git a/client/firewall/uspfilter/log/ringbuffer.go b/client/firewall/uspfilter/log/ringbuffer.go deleted file mode 100644 index dbc8f1289..000000000 --- a/client/firewall/uspfilter/log/ringbuffer.go +++ /dev/null @@ -1,85 +0,0 @@ -package log - -import "sync" - -// ringBuffer is a simple ring buffer implementation -type ringBuffer struct { - buf []byte - size int - r, w int64 // Read and write positions - mu sync.Mutex -} - -func newRingBuffer(size int) *ringBuffer { - return &ringBuffer{ - buf: make([]byte, size), - size: size, - } -} - -func (r *ringBuffer) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - - r.mu.Lock() - defer r.mu.Unlock() - - if len(p) > r.size { - p = p[:r.size] - } - - n = len(p) - - // Write data, handling wrap-around - pos := int(r.w % int64(r.size)) - writeLen := min(len(p), r.size-pos) - copy(r.buf[pos:], p[:writeLen]) - - // If we have more data and need to wrap around - if writeLen < len(p) { - copy(r.buf, p[writeLen:]) - } - - // Update write position - r.w += int64(n) - - return n, nil -} - -func (r *ringBuffer) Read(p []byte) (n int, err error) { - r.mu.Lock() - defer r.mu.Unlock() - - if r.w == r.r { - return 0, nil - } - - // Calculate available data accounting for wraparound - available := int(r.w - r.r) - if available < 0 { - available += r.size - } - available = min(available, r.size) - - // Limit read to buffer size - toRead := min(available, len(p)) - if toRead == 0 { - return 0, nil - } - - // Read data, handling wrap-around - pos := int(r.r % int64(r.size)) - readLen := min(toRead, r.size-pos) - n = copy(p, r.buf[pos:pos+readLen]) - - // If we need more data and need to wrap around - if readLen < toRead { - n += copy(p[readLen:toRead], r.buf[:toRead-readLen]) - } - - // Update read position - r.r += int64(n) - - return n, nil -} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 193526a52..baccab3fb 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface/netstack" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -96,6 +97,7 @@ type Manager struct { tcpTracker *conntrack.TCPTracker forwarder *forwarder.Forwarder logger *nblog.Logger + flowLogger nftypes.FlowLogger } // decoder for packages @@ -112,16 +114,16 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { - return create(iface, nil, disableServerRoutes) +func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { + return create(iface, nil, disableServerRoutes, flowLogger) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { if nativeFirewall == nil { return nil, errors.New("native firewall is nil") } - mgr, err := create(iface, nativeFirewall, disableServerRoutes) + mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) if err != nil { return nil, err } @@ -148,7 +150,7 @@ func parseCreateEnv() (bool, bool) { return disableConntrack, enableLocalForwarding } -func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { disableConntrack, enableLocalForwarding := parseCreateEnv() m := &Manager{ @@ -174,6 +176,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe routingEnabled: false, stateful: !disableConntrack, logger: nblog.NewFromLogrus(log.StandardLogger()), + flowLogger: flowLogger, netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, } @@ -185,9 +188,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe if disableConntrack { log.Info("conntrack is disabled") } else { - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) } // netstack needs the forwarder for local traffic @@ -304,7 +307,7 @@ func (m *Manager) initForwarder() error { return errors.New("forwarding not supported") } - forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack) + forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) if err != nil { m.routingEnabled = false return fmt.Errorf("create forwarder: %w", err) @@ -533,14 +536,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { // Track all protocols if stateful mode is enabled if m.stateful { - switch d.decoded[1] { - case layers.LayerTypeUDP: - m.trackUDPOutbound(d, srcIP, dstIP) - case layers.LayerTypeTCP: - m.trackTCPOutbound(d, srcIP, dstIP) - case layers.LayerTypeICMPv4: - m.trackICMPOutbound(d, srcIP, dstIP) - } + m.trackOutbound(d, srcIP, dstIP) } // Process UDP hooks even if stateful mode is disabled @@ -562,17 +558,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { } } -func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) { - flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.tcp.SrcPort), - uint16(d.tcp.DstPort), - flags, - ) -} - func getTCPFlags(tcp *layers.TCP) uint8 { var flags uint8 if tcp.SYN { @@ -596,13 +581,34 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { - m.udpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.udp.SrcPort), - uint16(d.udp.DstPort), - ) +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) { + transport := d.decoded[1] + switch transport { + case layers.LayerTypeUDP: + m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) + case layers.LayerTypeTCP: + flags := getTCPFlags(&d.tcp) + m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) + case layers.LayerTypeICMPv4: + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq) + } + } +} + +func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) { + transport := d.decoded[1] + switch transport { + case layers.LayerTypeUDP: + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) + case layers.LayerTypeTCP: + flags := getTCPFlags(&d.tcp) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) + case layers.LayerTypeICMPv4: + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq) + } + } } func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { @@ -618,17 +624,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo return false } -func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { - if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { - m.icmpTracker.TrackOutbound( - srcIP, - dstIP, - d.icmp4.Id, - d.icmp4.Seq, - ) - } -} - // dropFilter implements filtering logic for incoming packets. // If it returns true, the packet should be dropped. func (m *Manager) dropFilter(packetData []byte) bool { @@ -675,6 +670,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData return m.handleNetstackLocalTraffic(packetData) } + // track inbound packets to get the correct direction and session id for flows + m.trackInbound(d, srcIP, dstIP) + return false } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 875bb2425..62749ec49 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 9a1456d00..1a8ce10af 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock, false) + manager, err := Create(ifaceMock, false, flowLogger) require.NoError(t, err) require.NotNil(t, manager) @@ -302,7 +302,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock, false) + manager, err := Create(ifaceMock, false, flowLogger) require.NoError(tb, manager.EnableRouting()) require.NoError(tb, err) require.NotNil(tb, manager) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 386fa982b..d9576a1c0 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -1,8 +1,10 @@ package uspfilter import ( + "context" "fmt" "net" + "net/netip" "sync" "testing" "time" @@ -18,9 +20,11 @@ import ( "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/netflow" ) var logger = log.NewFromLogrus(logrus.StandardLogger()) +var flowLogger = netflow.NewManager(context.Background()).GetLogger() type IFaceMock struct { SetFilterFunc func(device.PacketFilter) error @@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false) + m, err := Create(ifaceMock, false, flowLogger) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock, false) + m, err := Create(ifaceMock, false, flowLogger) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -116,7 +120,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false) + m, err := Create(ifaceMock, false, flowLogger) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -187,7 +191,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -236,7 +240,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false) + m, err := Create(ifaceMock, false, flowLogger) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -279,7 +283,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock, false) + m, err := Create(ifaceMock, false, flowLogger) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -347,7 +351,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface, false) + manager, err := Create(iface, false, flowLogger) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -393,7 +397,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -401,7 +405,7 @@ func TestProcessOutgoingHooks(t *testing.T) { Mask: net.CIDRMask(16, 32), } manager.udpTracker.Close() - manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) + manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) defer func() { require.NoError(t, manager.Reset(nil)) }() @@ -479,7 +483,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false) + manager, err := Create(ifaceMock, false, flowLogger) require.NoError(t, err) time.Sleep(time.Second) @@ -506,7 +510,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false) + }, false, flowLogger) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -515,7 +519,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } manager.udpTracker.Close() // Close the existing tracker - manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) + manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.decoders = sync.Pool{ New: func() any { d := &decoder{ @@ -534,8 +538,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { }() // Set up packet parameters - srcIP := net.ParseIP("100.10.0.1") - dstIP := net.ParseIP("100.10.0.100") + srcIP := netip.MustParseAddr("100.10.0.1") + dstIP := netip.MustParseAddr("100.10.0.100") srcPort := uint16(51334) dstPort := uint16(53) @@ -543,8 +547,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { outboundIPv4 := &layers.IPv4{ TTL: 64, Version: 4, - SrcIP: srcIP, - DstIP: dstIP, + SrcIP: srcIP.AsSlice(), + DstIP: dstIP.AsSlice(), Protocol: layers.IPProtocolUDP, } outboundUDP := &layers.UDP{ @@ -573,11 +577,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.False(t, drop, "Initial outbound packet should not be dropped") // Verify connection was tracked - conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) require.True(t, exists, "Connection should be tracked after outbound packet") - require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") - require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") + require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match") + require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match") @@ -585,8 +589,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { inboundIPv4 := &layers.IPv4{ TTL: 64, Version: 4, - SrcIP: dstIP, // Original destination is now source - DstIP: srcIP, // Original source is now destination + SrcIP: dstIP.AsSlice(), // Original destination is now source + DstIP: srcIP.AsSlice(), // Original source is now destination Protocol: layers.IPProtocolUDP, } inboundUDP := &layers.UDP{ @@ -641,7 +645,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { // If the connection should still be valid, verify it exists if cp.shouldAllow { - conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) require.True(t, exists, "Connection should still exist during valid window") require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), "LastSeen should be updated for valid responses") diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index c9cbe1c5a..04da0e7d5 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,6 +1,7 @@ package acl import ( + "context" "net" "testing" @@ -10,9 +11,12 @@ import ( "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/acl/mocks" + "github.com/netbirdio/netbird/client/internal/netflow" mgmProto "github.com/netbirdio/netbird/management/proto" ) +var flowLogger = netflow.NewManager(context.Background()).GetLogger() + func TestDefaultManager(t *testing.T) { networkMap := &mgmProto.NetworkMap{ FirewallRules: []*mgmProto.FirewallRule{ @@ -52,7 +56,7 @@ func TestDefaultManager(t *testing.T) { ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) if err != nil { t.Errorf("create firewall: %v", err) return @@ -346,7 +350,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) if err != nil { t.Errorf("create firewall: %v", err) return diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 94b87124b..4741b9e1d 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" + "github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -29,6 +30,8 @@ import ( "github.com/netbirdio/netbird/formatter" ) +var flowLogger = netflow.NewManager(context.Background()).GetLogger() + type mocWGIface struct { filter device.PacketFilter } @@ -916,7 +919,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface, false) + pf, err := uspfilter.Create(wgIface, false, flowLogger) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index db0b8f38e..16eeff1c9 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -35,7 +35,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" - "github.com/netbirdio/netbird/client/internal/netflow/types" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" @@ -191,7 +191,7 @@ type Engine struct { persistNetworkMap bool latestNetworkMap *mgmProto.NetworkMap connSemaphore *semaphoregroup.SemaphoreGroup - flowManager types.FlowManager + flowManager nftypes.FlowManager } // Peer is an instance of the Connection Peer @@ -454,7 +454,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil @@ -721,11 +721,11 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error { return e.flowManager.Update(flowConfig) } -func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*types.FlowConfig, error) { +func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) { if config.GetInterval() == nil { return nil, errors.New("flow interval is nil") } - return &types.FlowConfig{ + return &nftypes.FlowConfig{ Enabled: config.GetEnabled(), URL: config.GetUrl(), TokenPayload: config.GetTokenPayload(), diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index 41c2bc0f1..dd4b60889 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -7,17 +7,52 @@ import ( "github.com/google/uuid" ) +type Protocol uint8 + +const ( + ProtocolUnknown = 0 + ICMP = 1 + TCP = 6 + UDP = 17 +) + +func (p Protocol) String() string { + switch p { + case 1: + return "ICMP" + case 6: + return "TCP" + case 17: + return "UDP" + default: + return "unknown" + } +} + type Type int const ( - TypeStart = iota + TypeUnknown = iota + TypeStart TypeEnd ) type Direction int +func (d Direction) String() string { + switch d { + case Ingress: + return "ingress" + case Egress: + return "egress" + default: + return "unknown" + } +} + const ( - Ingress = iota + DirectionUnknown = iota + Ingress Egress )