Add userspace flow implementation (#3393)

Author: Viktor Liu, 2025-02-28 11:08:35 +01:00, committed by GitHub
Parent: cccc615783
Commit: fa748a7ec2
27 changed files with 862 additions and 569 deletions

View File

@@ -10,17 +10,18 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	"github.com/netbirdio/netbird/client/firewall/uspfilter"
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

 // NewFirewall creates a firewall manager instance
-func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
+func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
 	if !iface.IsUserspaceBind() {
 		return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
 	}

 	// use userspace packet filtering firewall
-	fm, err := uspfilter.Create(iface, disableServerRoutes)
+	fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
 	if err != nil {
 		return nil, err
 	}
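Note on the change above: the commit threads a flow logger through every firewall constructor. A minimal, self-contained sketch of that dependency threading; the FlowLogger method set shown here is inferred from the StoreEvent calls later in this diff, not the full nftypes API:

// Sketch: threading a flow logger through a constructor chain.
// FlowLogger and EventFields are assumptions based on this diff.
package main

import "fmt"

type EventFields struct {
	Protocol   uint8
	SourcePort uint16
	DestPort   uint16
}

// FlowLogger is the dependency every tracker/forwarder receives.
type FlowLogger interface {
	StoreEvent(fields EventFields)
}

type printLogger struct{}

func (printLogger) StoreEvent(f EventFields) { fmt.Printf("flow event: %+v\n", f) }

// Manager stands in for the userspace firewall; it keeps the logger
// so it can recreate its trackers with the same dependency on Reset.
type Manager struct{ flowLogger FlowLogger }

func Create(flowLogger FlowLogger) (*Manager, error) {
	return &Manager{flowLogger: flowLogger}, nil
}

func main() {
	m, err := Create(printLogger{})
	if err != nil {
		panic(err)
	}
	m.flowLogger.StoreEvent(EventFields{Protocol: 6, SourcePort: 1234, DestPort: 80})
}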

View File

@@ -15,6 +15,7 @@ import (
 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
 	nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
 	"github.com/netbirdio/netbird/client/firewall/uspfilter"
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 )

@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
 // FWType is the type for the firewall type
 type FWType int

-func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
+func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
 	// on the linux system we try to user nftables or iptables
 	// in any case, because we need to allow netbird interface traffic
 	// so we use AllowNetbird traffic from these firewall managers
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
 	if err != nil {
 		log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
 	}
-	return createUserspaceFirewall(iface, fm, disableServerRoutes)
+	return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
 }

 func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
 	}
 }

-func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
+func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
 	var errUsp error
 	if fm != nil {
-		fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
+		fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
 	} else {
-		fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
+		fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
 	}
 	if errUsp != nil {
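For context, the fallback shape used here: try the native (kernel) firewall first, log and continue on failure, then either wrap the native manager with the userspace filter or run the userspace filter alone. A small illustrative sketch under assumed names, not the real package API:

// Sketch of the native-then-userspace fallback pattern.
package main

import (
	"errors"
	"fmt"
)

type Firewall interface{ Name() string }

type native struct{}

func (native) Name() string { return "nftables" }

type userspace struct{ inner Firewall }

func (u userspace) Name() string {
	if u.inner != nil {
		return "uspfilter(" + u.inner.Name() + ")"
	}
	return "uspfilter"
}

func createNative() (Firewall, error) { return nil, errors.New("kernel tables unavailable") }

func newFirewall() Firewall {
	fm, err := createNative()
	if err != nil {
		fmt.Println("failed to create native firewall, proceeding with userspace:", err)
	}
	// wraps the native manager when present, stands alone otherwise
	return userspace{inner: fm}
}

func main() { fmt.Println(newFirewall().Name()) }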

View File

@@ -22,17 +22,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
 	if m.udpTracker != nil {
 		m.udpTracker.Close()
-		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
 	}

 	if m.icmpTracker != nil {
 		m.icmpTracker.Close()
-		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
 	}

 	if m.tcpTracker != nil {
 		m.tcpTracker.Close()
-		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
 	}

 	if m.forwarder != nil {

View File

@@ -31,17 +31,17 @@ func (m *Manager) Reset(*statemanager.Manager) error {
 	if m.udpTracker != nil {
 		m.udpTracker.Close()
-		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
+		m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
 	}

 	if m.icmpTracker != nil {
 		m.icmpTracker.Close()
-		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
+		m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
 	}

 	if m.tcpTracker != nil {
 		m.tcpTracker.Close()
-		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
+		m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
 	}

 	if m.forwarder != nil {

View File

@@ -1,20 +1,26 @@
-// common.go
 package conntrack

 import (
+	"fmt"
 	"net"
-	"sync"
+	"net/netip"
 	"sync/atomic"
 	"time"

+	"github.com/google/uuid"

+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 // BaseConnTrack provides common fields and locking for all connection types
 type BaseConnTrack struct {
-	SourceIP   net.IP
-	DestIP     net.IP
+	FlowId     uuid.UUID
+	Direction  nftypes.Direction
+	SourceIP   netip.Addr
+	DestIP     netip.Addr
 	SourcePort uint16
 	DestPort   uint16
-	lastSeen   atomic.Int64 // Unix nano for atomic access
+	lastSeen   atomic.Int64
 }

 // these small methods will be inlined by the compiler
@@ -35,92 +41,27 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
 	return time.Since(lastSeen) > timeout
 }

-// IPAddr is a fixed-size IP address to avoid allocations
-type IPAddr [16]byte
-
-// MakeIPAddr creates an IPAddr from net.IP
-func MakeIPAddr(ip net.IP) (addr IPAddr) {
-	// Optimization: check for v4 first as it's more common
-	if ip4 := ip.To4(); ip4 != nil {
-		copy(addr[12:], ip4)
-	} else {
-		copy(addr[:], ip.To16())
-	}
-	return addr
-}
-
 // ConnKey uniquely identifies a connection
 type ConnKey struct {
-	SrcIP   IPAddr
-	DstIP   IPAddr
+	SrcIP   netip.Addr
+	DstIP   netip.Addr
 	SrcPort uint16
 	DstPort uint16
 }

+func (c ConnKey) String() string {
+	return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
+}
+
 // makeConnKey creates a connection key
 func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
+	srcAddr, _ := netip.AddrFromSlice(srcIP)
+	dstAddr, _ := netip.AddrFromSlice(dstIP)
+
 	return ConnKey{
-		SrcIP:   MakeIPAddr(srcIP),
-		DstIP:   MakeIPAddr(dstIP),
+		SrcIP:   srcAddr,
+		DstIP:   dstAddr,
 		SrcPort: srcPort,
 		DstPort: dstPort,
 	}
 }
-
-// ValidateIPs checks if IPs match without allocation
-func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
-	if ip4 := pktIP.To4(); ip4 != nil {
-		// Compare IPv4 addresses (last 4 bytes)
-		for i := 0; i < 4; i++ {
-			if connIP[12+i] != ip4[i] {
-				return false
-			}
-		}
-		return true
-	}
-	// Compare full IPv6 addresses
-	ip6 := pktIP.To16()
-	for i := 0; i < 16; i++ {
-		if connIP[i] != ip6[i] {
-			return false
-		}
-	}
-	return true
-}
-
-// PreallocatedIPs is a pool of IP byte slices to reduce allocations
-type PreallocatedIPs struct {
-	sync.Pool
-}
-
-// NewPreallocatedIPs creates a new IP pool
-func NewPreallocatedIPs() *PreallocatedIPs {
-	return &PreallocatedIPs{
-		Pool: sync.Pool{
-			New: func() interface{} {
-				ip := make(net.IP, 16)
-				return &ip
-			},
-		},
-	}
-}
-
-// Get retrieves an IP from the pool
-func (p *PreallocatedIPs) Get() net.IP {
-	return *p.Pool.Get().(*net.IP)
-}
-
-// Put returns an IP to the pool
-func (p *PreallocatedIPs) Put(ip net.IP) {
-	p.Pool.Put(&ip)
-}
-
-// copyIP copies an IP address efficiently
-func copyIP(dst, src net.IP) {
-	if len(src) == 16 {
-		copy(dst, src)
-	} else {
-		// Handle IPv4
-		copy(dst[12:], src.To4())
-	}
-}
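Why the netip.Addr migration removes so much code: netip.Addr is a comparable value type, so it keys a map directly and replaces the [16]byte conversion, the ValidateIPs comparison loop, and the IP byte-slice pool. A self-contained demo using only the standard library:

// netip.Addr as a comparable connection-key component.
package main

import (
	"fmt"
	"net"
	"net/netip"
)

type connKey struct {
	srcIP, dstIP     netip.Addr
	srcPort, dstPort uint16
}

func main() {
	src := net.ParseIP("192.168.1.2") // 16-byte (v4-in-v6) form of an IPv4 address
	dst := net.ParseIP("192.168.1.3")

	srcAddr, _ := netip.AddrFromSlice(src)
	dstAddr, _ := netip.AddrFromSlice(dst)

	conns := map[connKey]string{}
	conns[connKey{srcAddr, dstAddr, 12345, 53}] = "tracked"

	// Lookup is a plain value comparison: no per-packet allocation or loop.
	fmt.Println(conns[connKey{srcAddr, dstAddr, 12345, 53}]) // "tracked"

	// Unmap normalizes ::ffff:a.b.c.d to a.b.c.d for stable printing,
	// which is what the new ConnKey.String relies on above.
	fmt.Println(srcAddr.Unmap()) // 192.168.1.2
}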

View File

@@ -1,50 +1,23 @@
 package conntrack

 import (
+	"context"
 	"net"
 	"testing"

 	"github.com/sirupsen/logrus"

 	"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
+	"github.com/netbirdio/netbird/client/internal/netflow"
 )

 var logger = log.NewFromLogrus(logrus.StandardLogger())
+var flowLogger = netflow.NewManager(context.Background()).GetLogger()

-func BenchmarkIPOperations(b *testing.B) {
-	b.Run("MakeIPAddr", func(b *testing.B) {
-		ip := net.ParseIP("192.168.1.1")
-		b.ResetTimer()
-		for i := 0; i < b.N; i++ {
-			_ = MakeIPAddr(ip)
-		}
-	})
-
-	b.Run("ValidateIPs", func(b *testing.B) {
-		ip1 := net.ParseIP("192.168.1.1")
-		ip2 := net.ParseIP("192.168.1.1")
-		addr := MakeIPAddr(ip1)
-		b.ResetTimer()
-		for i := 0; i < b.N; i++ {
-			_ = ValidateIPs(addr, ip2)
-		}
-	})
-
-	b.Run("IPPool", func(b *testing.B) {
-		pool := NewPreallocatedIPs()
-		b.ResetTimer()
-		for i := 0; i < b.N; i++ {
-			ip := pool.Get()
-			pool.Put(ip)
-		}
-	})
-}
-
 // Memory pressure tests
 func BenchmarkMemoryPressure(b *testing.B) {
 	b.Run("TCPHighLoad", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout, logger)
+		tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		// Generate different IPs
@@ -69,7 +42,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
 	})

 	b.Run("UDPHighLoad", func(b *testing.B) {
-		tracker := NewUDPTracker(DefaultUDPTimeout, logger)
+		tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		// Generate different IPs

View File

@@ -1,13 +1,17 @@
 package conntrack

 import (
+	"fmt"
 	"net"
+	"net/netip"
 	"sync"
 	"time"

 	"github.com/google/gopacket/layers"
+	"github.com/google/uuid"

 	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 const (
@@ -19,18 +23,19 @@ const (
 // ICMPConnKey uniquely identifies an ICMP connection
 type ICMPConnKey struct {
-	// Supports both IPv4 and IPv6
-	SrcIP    [16]byte
-	DstIP    [16]byte
-	Sequence uint16 // ICMP sequence number
-	ID       uint16 // ICMP identifier
+	SrcIP    netip.Addr
+	DstIP    netip.Addr
+	Sequence uint16
+	ID       uint16
 }

+func (i *ICMPConnKey) String() string {
+	return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.Sequence, i.ID)
+}
+
 // ICMPConnTrack represents an ICMP connection state
 type ICMPConnTrack struct {
 	BaseConnTrack
+	Sequence uint16
+	ID       uint16
 }

 // ICMPTracker manages ICMP connection states
@@ -41,11 +46,11 @@ type ICMPTracker struct {
 	cleanupTicker *time.Ticker
 	mutex         sync.RWMutex
 	done          chan struct{}
-	ipPool        *PreallocatedIPs
+	flowLogger    nftypes.FlowLogger
 }

 // NewICMPTracker creates a new ICMP connection tracker
-func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
+func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
 	if timeout == 0 {
 		timeout = DefaultICMPTimeout
 	}
@@ -56,41 +61,65 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
 		timeout:       timeout,
 		cleanupTicker: time.NewTicker(ICMPCleanupInterval),
 		done:          make(chan struct{}),
-		ipPool:        NewPreallocatedIPs(),
+		flowLogger:    flowLogger,
 	}

 	go tracker.cleanupRoutine()
 	return tracker
 }

-// TrackOutbound records an outbound ICMP Echo Request
-func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
+func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
 	key := makeICMPKey(srcIP, dstIP, id, seq)

-	t.mutex.Lock()
+	t.mutex.RLock()
 	conn, exists := t.connections[key]
-	if !exists {
-		srcIPCopy := t.ipPool.Get()
-		dstIPCopy := t.ipPool.Get()
-		copyIP(srcIPCopy, srcIP)
-		copyIP(dstIPCopy, dstIP)
-		conn = &ICMPConnTrack{
-			BaseConnTrack: BaseConnTrack{
-				SourceIP: srcIPCopy,
-				DestIP:   dstIPCopy,
-			},
-		}
-		conn.UpdateLastSeen()
-		t.connections[key] = conn
-		t.logger.Trace("New ICMP connection %v", key)
-	}
-	t.mutex.Unlock()
-	conn.UpdateLastSeen()
+	t.mutex.RUnlock()
+
+	if exists {
+		conn.UpdateLastSeen()
+		return key, true
+	}
+
+	return key, false
+}
+
+// TrackOutbound records an outbound ICMP connection
+func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
+	if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
+		// if (inverted direction) conn is not tracked, track this direction
+		t.track(srcIP, dstIP, id, seq, nftypes.Egress)
+	}
+}
+
+// TrackInbound records an inbound ICMP Echo Request
+func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
+	t.track(srcIP, dstIP, id, seq, nftypes.Ingress)
+}
+
+// track is the common implementation for tracking both inbound and outbound ICMP connections
+func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, direction nftypes.Direction) {
+	key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
+	if exists {
+		return
+	}
+
+	conn := &ICMPConnTrack{
+		BaseConnTrack: BaseConnTrack{
+			FlowId:    uuid.New(),
+			Direction: direction,
+			SourceIP:  key.SrcIP,
+			DestIP:    key.DstIP,
+		},
+		ID:       id,
+		Sequence: seq,
+	}
+	conn.UpdateLastSeen()
+
+	t.mutex.Lock()
+	t.connections[key] = conn
+	t.mutex.Unlock()
+
+	t.logger.Trace("New %s ICMP connection %s", conn.Direction, key)
+	t.sendEvent(nftypes.TypeStart, key, conn)
 }

 // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
@@ -105,18 +134,13 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
 	conn, exists := t.connections[key]
 	t.mutex.RUnlock()

-	if !exists {
+	if !exists || conn.timeoutExceeded(t.timeout) {
 		return false
 	}
-	if conn.timeoutExceeded(t.timeout) {
-		return false
-	}

-	return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
-		ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
-		conn.ID == id &&
-		conn.Sequence == seq
+	conn.UpdateLastSeen()
+
+	return true
 }

 func (t *ICMPTracker) cleanupRoutine() {
@@ -129,17 +153,17 @@ func (t *ICMPTracker) cleanupRoutine() {
 		}
 	}
 }

 func (t *ICMPTracker) cleanup() {
 	t.mutex.Lock()
 	defer t.mutex.Unlock()

 	for key, conn := range t.connections {
 		if conn.timeoutExceeded(t.timeout) {
-			t.ipPool.Put(conn.SourceIP)
-			t.ipPool.Put(conn.DestIP)
 			delete(t.connections, key)
-			t.logger.Debug("Removed ICMP connection %v (timeout)", key)
+			t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
+			t.sendEvent(nftypes.TypeEnd, key, conn)
 		}
 	}
 }
@@ -150,19 +174,29 @@ func (t *ICMPTracker) Close() {
 	close(t.done)

 	t.mutex.Lock()
-	for _, conn := range t.connections {
-		t.ipPool.Put(conn.SourceIP)
-		t.ipPool.Put(conn.DestIP)
-	}
 	t.connections = nil
 	t.mutex.Unlock()
 }

+func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
+	t.flowLogger.StoreEvent(nftypes.EventFields{
+		FlowID:    conn.FlowId,
+		Type:      typ,
+		Direction: conn.Direction,
+		Protocol:  nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
+		SourceIP:  key.SrcIP,
+		DestIP:    key.DstIP,
+		// TODO: add icmp code/type,
+	})
+}
+
 // makeICMPKey creates an ICMP connection key
 func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
+	srcAddr, _ := netip.AddrFromSlice(srcIP)
+	dstAddr, _ := netip.AddrFromSlice(dstIP)
+
 	return ICMPConnKey{
-		SrcIP:    MakeIPAddr(srcIP),
-		DstIP:    MakeIPAddr(dstIP),
+		SrcIP:    srcAddr,
+		DstIP:    dstAddr,
 		ID:       id,
 		Sequence: seq,
 	}
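The interesting wrinkle above is TrackOutbound checking the inverted direction first: a reply packet belonging to an already-tracked inbound flow must refresh that flow rather than spawn a second, egress-direction entry. A minimal sketch of that idea, with a string key standing in for ICMPConnKey (names are illustrative):

// Reverse-direction check before tracking a new outbound flow.
package main

import (
	"fmt"
	"sync"
)

type direction string

type tracker struct {
	mu    sync.RWMutex
	conns map[string]direction
}

func (t *tracker) updateIfExists(key string) bool {
	t.mu.RLock()
	defer t.mu.RUnlock()
	_, ok := t.conns[key]
	return ok // a real tracker would also refresh last-seen here
}

func (t *tracker) track(key string, dir direction) {
	if t.updateIfExists(key) {
		return
	}
	t.mu.Lock()
	t.conns[key] = dir
	t.mu.Unlock()
	fmt.Printf("new %s flow %s\n", dir, key)
}

func (t *tracker) trackOutbound(src, dst string) {
	// reply to an existing inbound flow? then it is not a new egress flow
	if !t.updateIfExists(dst + "->" + src) {
		t.track(src+"->"+dst, "egress")
	}
}

func (t *tracker) trackInbound(src, dst string) { t.track(src+"->"+dst, "ingress") }

func main() {
	tr := &tracker{conns: map[string]direction{}}
	tr.trackInbound("peer", "local")   // new ingress flow
	tr.trackOutbound("local", "peer")  // reply: no second flow is created
	tr.trackOutbound("local", "other") // genuinely new egress flow
}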

View File

@@ -7,7 +7,7 @@ import (

 func BenchmarkICMPTracker(b *testing.B) {
 	b.Run("TrackOutbound", func(b *testing.B) {
-		tracker := NewICMPTracker(DefaultICMPTimeout, logger)
+		tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
 	})

 	b.Run("IsValidInbound", func(b *testing.B) {
-		tracker := NewICMPTracker(DefaultICMPTimeout, logger)
+		tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")

View File

@@ -8,7 +8,10 @@ import (
 	"sync/atomic"
 	"time"

+	"github.com/google/uuid"
+
 	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 const (
@@ -39,6 +42,35 @@ const (
 // TCPState represents the state of a TCP connection
 type TCPState int

+func (s TCPState) String() string {
+	switch s {
+	case TCPStateNew:
+		return "New"
+	case TCPStateSynSent:
+		return "SYN Sent"
+	case TCPStateSynReceived:
+		return "SYN Received"
+	case TCPStateEstablished:
+		return "Established"
+	case TCPStateFinWait1:
+		return "FIN Wait 1"
+	case TCPStateFinWait2:
+		return "FIN Wait 2"
+	case TCPStateClosing:
+		return "Closing"
+	case TCPStateTimeWait:
+		return "Time Wait"
+	case TCPStateCloseWait:
+		return "Close Wait"
+	case TCPStateLastAck:
+		return "Last ACK"
+	case TCPStateClosed:
+		return "Closed"
+	default:
+		return "Unknown"
+	}
+}
+
 const (
 	TCPStateNew TCPState = iota
 	TCPStateSynSent
@@ -53,19 +85,12 @@ const (
 	TCPStateClosed
 )

-// TCPConnKey uniquely identifies a TCP connection
-type TCPConnKey struct {
-	SrcIP   [16]byte
-	DstIP   [16]byte
-	SrcPort uint16
-	DstPort uint16
-}
-
 // TCPConnTrack represents a TCP connection state
 type TCPConnTrack struct {
 	BaseConnTrack
 	State       TCPState
 	established atomic.Bool
+	tombstone   atomic.Bool
 	sync.RWMutex
 }

@@ -79,6 +104,16 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
 	t.established.Store(state)
 }

+// IsTombstone safely checks if the connection is marked for deletion
+func (t *TCPConnTrack) IsTombstone() bool {
+	return t.tombstone.Load()
+}
+
+// SetTombstone safely marks the connection for deletion
+func (t *TCPConnTrack) SetTombstone() {
+	t.tombstone.Store(true)
+}
 // TCPTracker manages TCP connection states
 type TCPTracker struct {
 	logger        *nblog.Logger
@@ -87,68 +122,94 @@ type TCPTracker struct {
 	cleanupTicker *time.Ticker
 	done          chan struct{}
 	timeout       time.Duration
-	ipPool        *PreallocatedIPs
+	flowLogger    nftypes.FlowLogger
 }

 // NewTCPTracker creates a new TCP connection tracker
-func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
+func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
+	if timeout == 0 {
+		timeout = DefaultTCPTimeout
+	}
+
 	tracker := &TCPTracker{
 		logger:        logger,
 		connections:   make(map[ConnKey]*TCPConnTrack),
 		cleanupTicker: time.NewTicker(TCPCleanupInterval),
 		done:          make(chan struct{}),
 		timeout:       timeout,
-		ipPool:        NewPreallocatedIPs(),
+		flowLogger:    flowLogger,
 	}

 	go tracker.cleanupRoutine()
 	return tracker
 }

-// TrackOutbound processes an outbound TCP packet and updates connection state
-func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
-	// Create key before lock
+func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
 	key := makeConnKey(srcIP, dstIP, srcPort, dstPort)

-	t.mutex.Lock()
+	t.mutex.RLock()
 	conn, exists := t.connections[key]
-	if !exists {
-		// Use preallocated IPs
-		srcIPCopy := t.ipPool.Get()
-		dstIPCopy := t.ipPool.Get()
-		copyIP(srcIPCopy, srcIP)
-		copyIP(dstIPCopy, dstIP)
+	t.mutex.RUnlock()

-		conn = &TCPConnTrack{
-			BaseConnTrack: BaseConnTrack{
-				SourceIP:   srcIPCopy,
-				DestIP:     dstIPCopy,
-				SourcePort: srcPort,
-				DestPort:   dstPort,
-			},
-			State: TCPStateNew,
-		}
-		conn.UpdateLastSeen()
-		conn.established.Store(false)
-		t.connections[key] = conn
-		t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
-	}
-	t.mutex.Unlock()
+	if exists {
+		conn.Lock()
+		t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
+		conn.UpdateLastSeen()
+		conn.Unlock()

-	// Lock individual connection for state update
-	conn.Lock()
-	t.updateState(conn, flags, true)
-	conn.Unlock()
-	conn.UpdateLastSeen()
+		return key, true
+	}
+
+	return key, false
+}
+
+// TrackOutbound records an outbound TCP connection
+func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
+	if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
+		// if (inverted direction) conn is not tracked, track this direction
+		t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
+	}
+}
+
+// TrackInbound processes an inbound TCP packet and updates connection state
+func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
+	t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
+}
+
+// track is the common implementation for tracking both inbound and outbound connections
+func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
+	key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
+	if exists {
+		return
+	}
+
+	conn := &TCPConnTrack{
+		BaseConnTrack: BaseConnTrack{
+			FlowId:     uuid.New(),
+			Direction:  direction,
+			SourceIP:   key.SrcIP,
+			DestIP:     key.DstIP,
+			SourcePort: srcPort,
+			DestPort:   dstPort,
+		},
+	}

+	conn.UpdateLastSeen()
+	conn.established.Store(false)
+	conn.tombstone.Store(false)
+
+	t.logger.Trace("New %s TCP connection: %s", direction, key)
+	t.updateState(key, conn, flags, direction == nftypes.Egress)
+
+	t.mutex.Lock()
+	t.connections[key] = conn
+	t.mutex.Unlock()
+
+	t.sendEvent(nftypes.TypeStart, key, conn)
 }
 // IsValidInbound checks if an inbound TCP packet matches a tracked connection
 func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
+	if !isValidFlagCombination(flags) {
+		return false
+	}
+
 	key := makeConnKey(dstIP, srcIP, dstPort, srcPort)

 	t.mutex.RLock()
@@ -159,21 +220,25 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
 		return false
 	}

-	// Handle RST packets
+	// Handle RST flag specially - it always causes transition to closed
 	if flags&TCPRst != 0 {
-		conn.Lock()
-		if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
-			conn.State = TCPStateClosed
-			conn.SetEstablished(false)
-			conn.Unlock()
+		if conn.IsTombstone() {
 			return true
 		}
-		conn.Unlock()
-		return false
-	}

-	conn.Lock()
-	t.updateState(conn, flags, false)
+		conn.Lock()
+		conn.SetTombstone()
+		conn.State = TCPStateClosed
+		conn.SetEstablished(false)
+		conn.Unlock()
+
+		t.logger.Trace("TCP connection reset: %s", key)
+		t.sendEvent(nftypes.TypeEnd, key, conn)
+		return true
+	}
+
+	conn.Lock()
+	t.updateState(key, conn, flags, false)
 	conn.UpdateLastSeen()
 	isEstablished := conn.IsEstablished()
 	isValidState := t.isValidStateForFlags(conn.State, flags)
@@ -183,18 +248,15 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
 }

 // updateState updates the TCP connection state based on flags
-func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
-	// Handle RST flag specially - it always causes transition to closed
-	if flags&TCPRst != 0 {
-		conn.State = TCPStateClosed
-		conn.SetEstablished(false)
-
-		t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
-			conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
-		return
-	}
+func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
+	state := conn.State
+	defer func() {
+		if state != conn.State {
+			t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
+		}
+	}()

-	switch conn.State {
+	switch state {
 	case TCPStateNew:
 		if flags&TCPSyn != 0 && flags&TCPAck == 0 {
 			conn.State = TCPStateSynSent
@@ -241,6 +303,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
 	case TCPStateFinWait2:
 		if flags&TCPFin != 0 {
 			conn.State = TCPStateTimeWait
+			t.logger.Trace("TCP connection %s completed", key)
+			t.sendEvent(nftypes.TypeEnd, key, conn)
 		}

 	case TCPStateClosing:
@@ -248,8 +313,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
 			conn.State = TCPStateTimeWait
 			// Keep established = false from previous state
-			t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
-				conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
+			t.logger.Trace("TCP connection %s closed (simultaneous)", key)
+			t.sendEvent(nftypes.TypeEnd, key, conn)
 		}

 	case TCPStateCloseWait:
@@ -260,17 +325,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
 	case TCPStateLastAck:
 		if flags&TCPAck != 0 {
 			conn.State = TCPStateClosed
+			conn.SetTombstone()

-			t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
-				conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
+			// Send close event for gracefully closed connections
+			t.sendEvent(nftypes.TypeEnd, key, conn)
+			t.logger.Trace("TCP connection %s closed gracefully", key)
 		}
-
-	case TCPStateTimeWait:
-		// Stay in TIME-WAIT for 2MSL before transitioning to closed
-		// This is handled by the cleanup routine
-		t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
-			conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
 	}
 }
@@ -331,6 +391,12 @@ func (t *TCPTracker) cleanup() {
 	defer t.mutex.Unlock()

 	for key, conn := range t.connections {
+		if conn.IsTombstone() {
+			// Clean up tombstoned connections without sending an event
+			delete(t.connections, key)
+			continue
+		}
+
 		var timeout time.Duration
 		switch {
 		case conn.State == TCPStateTimeWait:
@@ -341,14 +407,16 @@
 			timeout = TCPHandshakeTimeout
 		}

-		lastSeen := conn.GetLastSeen()
-		if time.Since(lastSeen) > timeout {
-			// Return IPs to pool
-			t.ipPool.Put(conn.SourceIP)
-			t.ipPool.Put(conn.DestIP)
+		if conn.timeoutExceeded(timeout) {
 			delete(t.connections, key)
-			t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
+			t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
+
+			// event already handled by state change
+			if conn.State != TCPStateTimeWait {
+				t.sendEvent(nftypes.TypeEnd, key, conn)
+			}
 		}
 	}
 }
@@ -360,10 +428,6 @@ func (t *TCPTracker) Close() {
 	// Clean up all remaining IPs
 	t.mutex.Lock()
-	for _, conn := range t.connections {
-		t.ipPool.Put(conn.SourceIP)
-		t.ipPool.Put(conn.DestIP)
-	}
 	t.connections = nil
 	t.mutex.Unlock()
 }
@@ -381,3 +445,16 @@ func isValidFlagCombination(flags uint8) bool {
 	return true
 }
+
+func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
+	t.flowLogger.StoreEvent(nftypes.EventFields{
+		FlowID:     conn.FlowId,
+		Type:       typ,
+		Direction:  conn.Direction,
+		Protocol:   nftypes.TCP,
+		SourceIP:   key.SrcIP,
+		DestIP:     key.DstIP,
+		SourcePort: key.SrcPort,
+		DestPort:   key.DstPort,
+	})
+}
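The tombstone mechanic above exists so the end-of-flow event fires exactly once: closing marks the entry and emits the event; the periodic sweeper then deletes tombstoned entries silently. A standalone sketch of the idea; the diff sets the flag under the per-connection lock, while this sketch uses CompareAndSwap for the same exactly-once property:

// Tombstone pattern: report flow end once, sweep later without a second event.
package main

import (
	"fmt"
	"sync/atomic"
)

type conn struct {
	id        string
	tombstone atomic.Bool
}

func sendEndEvent(c *conn) { fmt.Println("flow end:", c.id) }

func onRST(c *conn) {
	if c.tombstone.CompareAndSwap(false, true) {
		sendEndEvent(c) // first closer reports; later RSTs are no-ops
	}
}

func cleanup(conns map[string]*conn) {
	for id, c := range conns {
		if c.tombstone.Load() {
			delete(conns, id) // tombstoned: remove without a second event
			continue
		}
		// timeout handling for live connections would go here
	}
}

func main() {
	c := &conn{id: "100.64.0.1:1234 -> 100.64.0.2:80"}
	conns := map[string]*conn{c.id: c}
	onRST(c)
	onRST(c) // idempotent: no duplicate end event
	cleanup(conns)
	fmt.Println("remaining connections:", len(conns))
}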

View File

@@ -9,7 +9,7 @@ import (
 )

 func TestTCPStateMachine(t *testing.T) {
-	tracker := NewTCPTracker(DefaultTCPTimeout, logger)
+	tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 	defer tracker.Close()

 	srcIP := net.ParseIP("100.64.0.1")
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			t.Helper()

-			tracker = NewTCPTracker(DefaultTCPTimeout, logger)
+			tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 			tt.test(t)
 		})
 	}
@@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
 }

 func TestRSTHandling(t *testing.T) {
-	tracker := NewTCPTracker(DefaultTCPTimeout, logger)
+	tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 	defer tracker.Close()

 	srcIP := net.ParseIP("100.64.0.1")
@@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
 func BenchmarkTCPTracker(b *testing.B) {
 	b.Run("TrackOutbound", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout, logger)
+		tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
 	})

 	b.Run("IsValidInbound", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout, logger)
+		tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
 	})

 	b.Run("ConcurrentAccess", func(b *testing.B) {
-		tracker := NewTCPTracker(DefaultTCPTimeout, logger)
+		tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
 // Benchmark connection cleanup
 func BenchmarkCleanup(b *testing.B) {
 	b.Run("TCPCleanup", func(b *testing.B) {
-		tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
+		tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
 		defer tracker.Close()

 		// Pre-populate with expired connections

View File

@@ -5,7 +5,10 @@ import (
 	"sync"
 	"time"

+	"github.com/google/uuid"
+
 	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 const (
@@ -28,11 +31,11 @@ type UDPTracker struct {
 	cleanupTicker *time.Ticker
 	mutex         sync.RWMutex
 	done          chan struct{}
-	ipPool        *PreallocatedIPs
+	flowLogger    nftypes.FlowLogger
 }

 // NewUDPTracker creates a new UDP connection tracker
-func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
+func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
 	if timeout == 0 {
 		timeout = DefaultUDPTimeout
 	}
@@ -43,7 +46,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
 		timeout:       timeout,
 		cleanupTicker: time.NewTicker(UDPCleanupInterval),
 		done:          make(chan struct{}),
-		ipPool:        NewPreallocatedIPs(),
+		flowLogger:    flowLogger,
 	}

 	go tracker.cleanupRoutine()
@@ -52,32 +55,57 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {

 // TrackOutbound records an outbound UDP connection
 func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
+	if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
+		// if (inverted direction) conn is not tracked, track this direction
+		t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
+	}
+}
+
+// TrackInbound records an inbound UDP connection
+func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
+	t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
+}
+
+func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
 	key := makeConnKey(srcIP, dstIP, srcPort, dstPort)

-	t.mutex.Lock()
+	t.mutex.RLock()
 	conn, exists := t.connections[key]
-	if !exists {
-		srcIPCopy := t.ipPool.Get()
-		dstIPCopy := t.ipPool.Get()
-		copyIP(srcIPCopy, srcIP)
-		copyIP(dstIPCopy, dstIP)
+	t.mutex.RUnlock()

-		conn = &UDPConnTrack{
-			BaseConnTrack: BaseConnTrack{
-				SourceIP:   srcIPCopy,
-				DestIP:     dstIPCopy,
-				SourcePort: srcPort,
-				DestPort:   dstPort,
-			},
-		}
-		conn.UpdateLastSeen()
-		t.connections[key] = conn
-		t.logger.Trace("New UDP connection: %v", conn)
+	if exists {
+		conn.UpdateLastSeen()
+		return key, true
 	}
-	t.mutex.Unlock()

-	conn.UpdateLastSeen()
+	return key, false
+}
+
+// track is the common implementation for tracking both inbound and outbound connections
+func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
+	key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
+	if exists {
+		return
+	}
+
+	conn := &UDPConnTrack{
+		BaseConnTrack: BaseConnTrack{
+			FlowId:     uuid.New(),
+			Direction:  direction,
+			SourceIP:   key.SrcIP,
+			DestIP:     key.DstIP,
+			SourcePort: srcPort,
+			DestPort:   dstPort,
+		},
+	}
+	conn.UpdateLastSeen()
+
+	t.mutex.Lock()
+	t.connections[key] = conn
+	t.mutex.Unlock()
+
+	t.logger.Trace("New %s UDP connection: %s", direction, key)
+	t.sendEvent(nftypes.TypeStart, key, conn)
 }

 // IsValidInbound checks if an inbound packet matches a tracked connection
@@ -88,18 +116,13 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
 	conn, exists := t.connections[key]
 	t.mutex.RUnlock()

-	if !exists {
+	if !exists || conn.timeoutExceeded(t.timeout) {
 		return false
 	}

-	if conn.timeoutExceeded(t.timeout) {
-		return false
-	}
+	conn.UpdateLastSeen()

-	return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
-		ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
-		conn.DestPort == srcPort &&
-		conn.SourcePort == dstPort
+	return true
 }

 // cleanupRoutine periodically removes stale connections
@@ -120,11 +143,10 @@ func (t *UDPTracker) cleanup() {
 	for key, conn := range t.connections {
 		if conn.timeoutExceeded(t.timeout) {
-			t.ipPool.Put(conn.SourceIP)
-			t.ipPool.Put(conn.DestIP)
 			delete(t.connections, key)
-			t.logger.Trace("Removed UDP connection %v (timeout)", conn)
+			t.logger.Trace("Removed UDP connection %s (timeout)", key)
+			t.sendEvent(nftypes.TypeEnd, key, conn)
 		}
 	}
 }
@@ -135,10 +157,6 @@ func (t *UDPTracker) Close() {
 	close(t.done)

 	t.mutex.Lock()
-	for _, conn := range t.connections {
-		t.ipPool.Put(conn.SourceIP)
-		t.ipPool.Put(conn.DestIP)
-	}
 	t.connections = nil
 	t.mutex.Unlock()
 }
@@ -150,14 +168,23 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d
 	key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
 	conn, exists := t.connections[key]
-	if !exists {
-		return nil, false
-	}
-
-	return conn, true
+	return conn, exists
 }

 // Timeout returns the configured timeout duration for the tracker
 func (t *UDPTracker) Timeout() time.Duration {
 	return t.timeout
 }
+
+func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
+	t.flowLogger.StoreEvent(nftypes.EventFields{
+		FlowID:     conn.FlowId,
+		Type:       typ,
+		Direction:  conn.Direction,
+		Protocol:   nftypes.UDP,
+		SourceIP:   key.SrcIP,
+		DestIP:     key.DstIP,
+		SourcePort: key.SrcPort,
+		DestPort:   key.DstPort,
+	})
+}
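A design note on the locking change above: IsValidInbound now only ever takes the map's read lock, because the last-seen timestamp lives on the connection itself as an atomic Unix-nano value. Refreshing it needs no write lock at all. Sketch of that idiom under illustrative names:

// Read-locked hot path with an atomic per-connection last-seen.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type udpConn struct{ lastSeen atomic.Int64 }

func (c *udpConn) UpdateLastSeen() { c.lastSeen.Store(time.Now().UnixNano()) }
func (c *udpConn) timeoutExceeded(timeout time.Duration) bool {
	return time.Since(time.Unix(0, c.lastSeen.Load())) > timeout
}

type udpTracker struct {
	mu      sync.RWMutex
	conns   map[string]*udpConn
	timeout time.Duration
}

func (t *udpTracker) IsValidInbound(key string) bool {
	t.mu.RLock()
	conn, ok := t.conns[key]
	t.mu.RUnlock() // map read done; the rest is per-connection and atomic

	if !ok || conn.timeoutExceeded(t.timeout) {
		return false
	}
	conn.UpdateLastSeen()
	return true
}

func main() {
	c := &udpConn{}
	c.UpdateLastSeen()
	t := &udpTracker{conns: map[string]*udpConn{"k": c}, timeout: 30 * time.Second}
	fmt.Println(t.IsValidInbound("k")) // true
	fmt.Println(t.IsValidInbound("x")) // false
}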

View File

@@ -2,6 +2,7 @@ package conntrack

 import (
 	"net"
+	"net/netip"
 	"testing"
 	"time"

@@ -29,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			tracker := NewUDPTracker(tt.timeout, logger)
+			tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
 			assert.NotNil(t, tracker)
 			assert.Equal(t, tt.wantTimeout, tracker.timeout)
 			assert.NotNil(t, tracker.connections)
@@ -40,29 +41,29 @@ func TestNewUDPTracker(t *testing.T) {
 }

 func TestUDPTracker_TrackOutbound(t *testing.T) {
-	tracker := NewUDPTracker(DefaultUDPTimeout, logger)
+	tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
 	defer tracker.Close()

-	srcIP := net.ParseIP("192.168.1.2")
-	dstIP := net.ParseIP("192.168.1.3")
+	srcIP := netip.MustParseAddr("192.168.1.2")
+	dstIP := netip.MustParseAddr("192.168.1.3")
 	srcPort := uint16(12345)
 	dstPort := uint16(53)

-	tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
+	tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)

 	// Verify connection was tracked
-	key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
+	key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
 	conn, exists := tracker.connections[key]
 	require.True(t, exists)
-	assert.True(t, conn.SourceIP.Equal(srcIP))
-	assert.True(t, conn.DestIP.Equal(dstIP))
+	assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
+	assert.True(t, conn.DestIP.Compare(dstIP) == 0)
 	assert.Equal(t, srcPort, conn.SourcePort)
 	assert.Equal(t, dstPort, conn.DestPort)
 	assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
 }

 func TestUDPTracker_IsValidInbound(t *testing.T) {
-	tracker := NewUDPTracker(1*time.Second, logger)
+	tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
 	defer tracker.Close()

 	srcIP := net.ParseIP("192.168.1.2")
@@ -160,8 +161,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
 		timeout:       timeout,
 		cleanupTicker: time.NewTicker(cleanupInterval),
 		done:          make(chan struct{}),
-		ipPool:        NewPreallocatedIPs(),
 		logger:        logger,
+		flowLogger:    flowLogger,
 	}

 	// Start cleanup routine
@@ -211,7 +212,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
 func BenchmarkUDPTracker(b *testing.B) {
 	b.Run("TrackOutbound", func(b *testing.B) {
-		tracker := NewUDPTracker(DefaultUDPTimeout, logger)
+		tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")
@@ -224,7 +225,7 @@ func BenchmarkUDPTracker(b *testing.B) {
 	})

 	b.Run("IsValidInbound", func(b *testing.B) {
-		tracker := NewUDPTracker(DefaultUDPTimeout, logger)
+		tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
 		defer tracker.Close()

 		srcIP := net.ParseIP("192.168.1.1")

View File

@@ -18,6 +18,7 @@ import (
 	"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
 	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 const (
@@ -29,6 +30,7 @@ const (
 type Forwarder struct {
 	logger       *nblog.Logger
+	flowLogger   nftypes.FlowLogger
 	stack        *stack.Stack
 	endpoint     *endpoint
 	udpForwarder *udpForwarder
@@ -38,7 +40,7 @@ type Forwarder struct {
 	netstack     bool
 }

-func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
+func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
 	s := stack.New(stack.Options{
 		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
 		TransportProtocols: []stack.TransportProtocolFactory{
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
 	ctx, cancel := context.WithCancel(context.Background())
 	f := &Forwarder{
 		logger:       logger,
+		flowLogger:   flowLogger,
 		stack:        s,
 		endpoint:     endpoint,
-		udpForwarder: newUDPForwarder(mtu, logger),
+		udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
 		ctx:          ctx,
 		cancel:       cancel,
 		netstack:     netstack,

View File

@@ -3,14 +3,21 @@ package forwarder
 import (
 	"context"
 	"net"
+	"net/netip"
 	"time"

+	"github.com/google/uuid"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
+
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 // handleICMP handles ICMP packets from the network stack
 func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
+	flowID := uuid.New()
+	f.sendICMPEvent(nftypes.TypeStart, flowID, id)
+
 	ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
 	defer cancel()

@@ -20,6 +27,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
 	if err != nil {
 		f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
+		f.sendICMPEvent(nftypes.TypeEnd, flowID, id)
+
 		// This will make netstack reply on behalf of the original destination, that's ok for now
 		return false
 	}
@@ -27,6 +36,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
 		if err := conn.Close(); err != nil {
 			f.logger.Debug("Failed to close ICMP socket: %v", err)
 		}
+
+		f.sendICMPEvent(nftypes.TypeEnd, flowID, id)
 	}()

 	dstIP := f.determineDialAddr(id.LocalAddress)
@@ -101,9 +112,25 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
 	if err := f.InjectIncomingPacket(fullPacket); err != nil {
 		f.logger.Error("Failed to inject ICMP response: %v", err)
 		return true
 	}

 	f.logger.Trace("Forwarded ICMP echo reply for %v", id)
 	return true
 }

+// sendICMPEvent stores flow events for ICMP packets
+func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
+	f.flowLogger.StoreEvent(nftypes.EventFields{
+		FlowID:    flowID,
+		Type:      typ,
+		Direction: nftypes.Ingress,
+		Protocol:  1,
+		// TODO: handle ipv6
+		SourceIP:   netip.AddrFrom4(id.LocalAddress.As4()),
+		DestIP:     netip.AddrFrom4(id.RemoteAddress.As4()),
+		SourcePort: id.LocalPort,
+		DestPort:   id.RemotePort,
+	})
+}

View File

@@ -5,18 +5,25 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"net/netip"

+	"github.com/google/uuid"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
 	"gvisor.dev/gvisor/pkg/waiter"
+
+	nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
 )

 // handleTCP is called by the TCP forwarder for new connections.
 func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
 	id := r.ID()

+	flowID := uuid.New()
+	f.sendTCPEvent(nftypes.TypeStart, flowID, id)
+
 	dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)

 	outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
@@ -46,10 +53,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
 	f.logger.Trace("forwarder: established TCP connection %v", id)

-	go f.proxyTCP(id, inConn, outConn, ep)
+	go f.proxyTCP(id, inConn, outConn, ep, flowID)
 }

-func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
+func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
 	defer func() {
 		if err := inConn.Close(); err != nil {
 			f.logger.Debug("forwarder: inConn close error: %v", err)
@@ -58,6 +65,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
 			f.logger.Debug("forwarder: outConn close error: %v", err)
 		}
 		ep.Close()
+
+		f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
 	}()

 	// Create context for managing the proxy goroutines
@@ -88,3 +97,18 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
 			return
 		}
 	}
 }
+
+func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
+	f.flowLogger.StoreEvent(nftypes.EventFields{
+		FlowID:    flowID,
+		Type:      typ,
+		Direction: nftypes.Ingress,
+		Protocol:  6,
+		// TODO: handle ipv6
+		SourceIP:   netip.AddrFrom4(id.LocalAddress.As4()),
+		DestIP:     netip.AddrFrom4(id.RemoteAddress.As4()),
+		SourcePort: id.LocalPort,
+		DestPort:   id.RemotePort,
+	})
+}
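The forwarder-side lifecycle above follows one simple contract: mint a flow UUID when the connection is accepted, emit the start event immediately, and guarantee the matching end event with a defer in the proxy goroutine. A compact sketch of that shape, using github.com/google/uuid (the same module this diff imports); the event plumbing is illustrative:

// One flow ID ties paired start/end events together; defer guarantees the end.
package main

import (
	"fmt"

	"github.com/google/uuid"
)

type eventType string

const (
	typeStart eventType = "start"
	typeEnd   eventType = "end"
)

func sendEvent(typ eventType, flowID uuid.UUID, conn string) {
	fmt.Printf("%s %s %s\n", typ, flowID, conn)
}

func handleConn(conn string) {
	flowID := uuid.New() // one ID for the whole flow
	sendEvent(typeStart, flowID, conn)
	proxy(conn, flowID)
}

func proxy(conn string, flowID uuid.UUID) {
	defer sendEvent(typeEnd, flowID, conn) // fires on every exit path
	// ... copy bytes in both directions until either side closes ...
}

func main() { handleConn("100.64.0.1:1234 -> 10.0.0.5:443") }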

View File

@ -5,10 +5,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@ -16,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@ -28,11 +31,13 @@ type udpPacketConn struct {
lastSeen atomic.Int64 lastSeen atomic.Int64
cancel context.CancelFunc cancel context.CancelFunc
ep tcpip.Endpoint ep tcpip.Endpoint
flowID uuid.UUID
} }
type udpForwarder struct { type udpForwarder struct {
sync.RWMutex sync.RWMutex
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool bufPool sync.Pool
ctx context.Context ctx context.Context
@ -44,10 +49,11 @@ type idleConn struct {
conn *udpPacketConn conn *udpPacketConn
} }
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger, logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn), conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@ -83,6 +89,21 @@ func (f *udpForwarder) Stop() {
} }
} }
// sendUDPEvent stores flow events for UDP connections
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
// cleanup periodically removes idle UDP connections // cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() { func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(time.Minute)
@ -119,6 +140,8 @@ func (f *udpForwarder) cleanup() {
f.Unlock() f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
} }
} }
} }
@ -141,10 +164,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
return return
} }
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message
return
}
@ -157,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
@ -168,6 +196,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn: outConn,
cancel: connCancel,
ep: ep,
flowID: flowID,
}
pConn.updateLastSeen()
@ -182,6 +211,8 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
f.udpForwarder.conns[id] = pConn
@ -206,6 +237,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
}()
errChan := make(chan error, 2)
@ -231,6 +264,21 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
}
}
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17, // UDP protocol number
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
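Taken together, handleUDP emits one TypeStart event when it accepts a flow, and every teardown path (dial failure, endpoint failure, idle cleanup, proxy exit) emits a matching TypeEnd carrying the same flow ID. A minimal sketch of that pairing, assuming a flow logger obtained from the netflow manager as the tests later in this change do; the addresses and ports are illustrative:

    package main

    import (
        "context"
        "net/netip"

        "github.com/google/uuid"

        "github.com/netbirdio/netbird/client/internal/netflow"
        nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
    )

    func main() {
        flowLogger := netflow.NewManager(context.Background()).GetLogger()
        flowID := uuid.New()
        fields := nftypes.EventFields{
            FlowID:     flowID,
            Type:       nftypes.TypeStart,
            Direction:  nftypes.Ingress,
            Protocol:   17, // UDP
            SourceIP:   netip.MustParseAddr("100.10.0.1"),
            DestIP:     netip.MustParseAddr("100.10.0.100"),
            SourcePort: 51334,
            DestPort:   53,
        }
        flowLogger.StoreEvent(fields)
        // ... traffic is proxied until the connection closes or idles out ...
        fields.Type = nftypes.TypeEnd
        flowLogger.StoreEvent(fields) // the shared FlowID ties start and end together
    }

Both Forwarder and udpForwarder get the same sendUDPEvent helper so every code path that tears a connection down can emit the closing event locally.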
View File
@ -1,4 +1,4 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
// Package log provides a high-performance, non-blocking logger for userspace networking
package log
import (
@ -13,13 +13,12 @@ import (
)
const (
maxBatchSize = 1024 * 16 // 16KB max batch size
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2 // 2KB per message
maxMessageSize = 1024 * 2
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
)
// Level represents log severity
type Level uint32
const (
@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC",
}
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
buffer *ringBuffer
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
New: func() any {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
return l
}
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
}
*buf = append(*buf, '\n')
}
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
func (l *Logger) Error(format string, args ...interface{}) {
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
func (l *Logger) Warn(format string, args ...interface{}) {
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
func (l *Logger) Info(format string, args ...interface{}) {
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
func (l *Logger) Debug(format string, args ...interface{}) {
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
func (l *Logger) Trace(format string, args ...interface{}) {
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
// worker periodically flushes the buffer
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize)
buffer := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
l.handleShutdown(&buffer)
return
case <-ticker.C:
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
_, _ = l.output.Write(buf[:n])
l.flushBuffer(&buffer)
case msg := <-l.msgChannel:
l.processMessage(msg, &buffer)
l.processBatch(&buffer)
}
}
}
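The reworked logger keeps a small public surface: construct from an existing logrus logger, log with printf-style calls, and stop with a bounded context so queued messages are flushed. A usage sketch, assuming the nblog import alias used elsewhere in this change and the Stop method exercised by the tests below:

    package main

    import (
        "context"
        "time"

        "github.com/sirupsen/logrus"

        nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
    )

    func main() {
        logger := nblog.NewFromLogrus(logrus.StandardLogger())
        logger.SetLevel(nblog.LevelTrace)
        logger.Trace("TCP connection %s:%d -> %s:%d state changed to %d",
            "192.168.1.1", 12345, "10.0.0.1", 443, 4)

        // Stop with a deadline: handleShutdown drains what it can, then flushes.
        ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
        defer cancel()
        _ = logger.Stop(ctx)
    }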
View File
@ -0,0 +1,121 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}
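These benchmarks can also be driven programmatically when comparing logger configurations outside of go test; a sketch using the standard library's testing.Benchmark helper, mirroring the discard writer defined above (this harness is illustrative, not part of the change):

    package main

    import (
        "fmt"
        "io"
        "testing"

        "github.com/sirupsen/logrus"

        nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
    )

    func main() {
        logrusLogger := logrus.New()
        logrusLogger.SetOutput(io.Discard)
        logrusLogger.SetLevel(logrus.TraceLevel)
        logger := nblog.NewFromLogrus(logrusLogger)

        // testing.Benchmark reruns the function with growing b.N until timing stabilizes.
        res := testing.Benchmark(func(b *testing.B) {
            for i := 0; i < b.N; i++ {
                logger.Trace("Connection established")
            }
        })
        fmt.Println(res)
    }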
View File
@ -1,85 +0,0 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}
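The ring buffer removed above accepted writes unconditionally and truncated or wrapped under pressure; the buffered channel that replaces it (see the log method earlier in this diff) makes the overflow policy explicit: a non-blocking send that drops the message when the consumer lags. The pattern in isolation, as a self-contained sketch:

    package main

    import "fmt"

    // trySend never blocks: it reports false when the buffered channel is full,
    // which is the drop-on-overflow policy the logger's log method relies on.
    func trySend(ch chan string, msg string) bool {
        select {
        case ch <- msg:
            return true
        default:
            return false
        }
    }

    func main() {
        ch := make(chan string, 1)
        fmt.Println(trySend(ch, "first"))  // true: buffer has room
        fmt.Println(trySend(ch, "second")) // false: buffer full, message dropped
    }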
View File
@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@ -96,6 +97,7 @@ type Manager struct {
tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
logger *nblog.Logger
flowLogger nftypes.FlowLogger
}
// decoder for packets
@ -112,16 +114,16 @@ type decoder struct {
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes)
return create(iface, nil, disableServerRoutes, flowLogger)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
if err != nil {
return nil, err
}
@ -148,7 +150,7 @@ func parseCreateEnv() (bool, bool) {
return disableConntrack, enableLocalForwarding
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{
@ -174,6 +176,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
}
@ -185,9 +188,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
}
// netstack needs the forwarder for local traffic
@ -304,7 +307,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil { if err != nil {
m.routingEnabled = false m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
@ -533,14 +536,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
// Track all protocols if stateful mode is enabled
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
}
m.trackOutbound(d, srcIP, dstIP)
}
// Process UDP hooks even if stateful mode is disabled
@ -562,17 +558,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
}
}
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
@ -596,13 +581,34 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
m.udpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq)
}
}
}
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq)
}
}
} }
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
@ -618,17 +624,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
)
}
}
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool {
@ -675,6 +670,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
return m.handleNetstackLocalTraffic(packetData)
}
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP)
return false
}
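For context on the dispatch above: gopacket's DecodingLayerParser records the decoded layer types in a slice, so d.decoded[1] names the transport layer of the current packet, which is exactly what trackOutbound and trackInbound switch on. A self-contained sketch of that round trip, building a packet the same way the tests in this change do:

    package main

    import (
        "fmt"
        "net"

        "github.com/google/gopacket"
        "github.com/google/gopacket/layers"
    )

    func main() {
        // Build a minimal IPv4/UDP packet.
        ipLayer := &layers.IPv4{Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP,
            SrcIP: net.ParseIP("100.10.0.1"), DstIP: net.ParseIP("100.10.0.100")}
        udpLayer := &layers.UDP{SrcPort: 51334, DstPort: 53}
        _ = udpLayer.SetNetworkLayerForChecksum(ipLayer)

        buf := gopacket.NewSerializeBuffer()
        _ = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true},
            ipLayer, udpLayer, gopacket.Payload([]byte("x")))

        // Decode it back: decoded[0] is the network layer, decoded[1] the transport.
        var ip4 layers.IPv4
        var udp layers.UDP
        var tcp layers.TCP
        parser := gopacket.NewDecodingLayerParser(layers.LayerTypeIPv4, &ip4, &udp, &tcp)
        decoded := []gopacket.LayerType{}
        _ = parser.DecodeLayers(buf.Bytes(), &decoded)

        if len(decoded) > 1 && decoded[1] == layers.LayerTypeUDP {
            fmt.Println("transport is UDP, ports", udp.SrcPort, udp.DstPort)
        }
    }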
View File
@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Reset(nil))
}) })
View File
@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false)
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
require.NotNil(t, manager)
@ -302,7 +302,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false)
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)
View File
@ -1,8 +1,10 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
@ -18,9 +20,11 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -116,7 +120,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -187,7 +191,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@ -236,7 +240,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -279,7 +283,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -347,7 +351,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false)
manager, err := Create(iface, false, flowLogger)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@ -393,7 +397,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@ -401,7 +405,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
@ -479,7 +483,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@ -506,7 +510,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@ -515,7 +519,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
@ -534,8 +538,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}()
// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
dstIP := netip.MustParseAddr("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
@ -543,8 +547,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{ outboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: srcIP, SrcIP: srcIP.AsSlice(),
DstIP: dstIP, DstIP: dstIP.AsSlice(),
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
outboundUDP := &layers.UDP{ outboundUDP := &layers.UDP{
@ -573,11 +577,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@ -585,8 +589,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{ inboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: dstIP, // Original destination is now source SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP, // Original source is now destination DstIP: srcIP.AsSlice(), // Original source is now destination
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
inboundUDP := &layers.UDP{ inboundUDP := &layers.UDP{
@ -641,7 +645,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")
View File
@ -1,6 +1,7 @@
package acl
import (
"context"
"net"
"testing"
@ -10,9 +11,12 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@ -52,7 +56,7 @@ func TestDefaultManager(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return
@ -346,7 +350,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return
View File
@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@ -29,6 +30,8 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter device.PacketFilter
} }
@ -916,7 +919,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface, false)
pf, err := uspfilter.Create(wgIface, false, flowLogger)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
View File
@ -35,7 +35,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
@ -191,7 +191,7 @@ type Engine struct {
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager types.FlowManager
flowManager nftypes.FlowManager
}
// Peer is an instance of the Connection Peer
@ -454,7 +454,7 @@ func (e *Engine) createFirewall() error {
}
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
@ -721,11 +721,11 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
return e.flowManager.Update(flowConfig)
}
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*types.FlowConfig, error) {
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
if config.GetInterval() == nil {
return nil, errors.New("flow interval is nil")
}
return &types.FlowConfig{
return &nftypes.FlowConfig{
Enabled: config.GetEnabled(),
URL: config.GetUrl(),
TokenPayload: config.GetTokenPayload(),
View File
@ -7,17 +7,52 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
type Protocol uint8
const (
ProtocolUnknown Protocol = 0
ICMP Protocol = 1
TCP Protocol = 6
UDP Protocol = 17
)
func (p Protocol) String() string {
switch p {
case ICMP:
return "ICMP"
case TCP:
return "TCP"
case UDP:
return "UDP"
default:
return "unknown"
}
}
type Type int
const (
TypeStart = iota
TypeUnknown Type = iota
TypeStart
TypeEnd
)
type Direction int
func (d Direction) String() string {
switch d {
case Ingress:
return "ingress"
case Egress:
return "egress"
default:
return "unknown"
}
}
const (
Ingress = iota
DirectionUnknown Direction = iota
Ingress
Egress
)
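With these definitions, protocol numbers and directions print as readable strings through the String methods above. A small, self-contained illustration, assuming the import path used elsewhere in this change:

    package main

    import (
        "fmt"

        nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
    )

    func main() {
        fmt.Println(nftypes.Protocol(17))              // UDP
        fmt.Println(nftypes.Protocol(6))               // TCP
        fmt.Println(nftypes.Direction(nftypes.Egress)) // egress
    }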