Add userspace flow implementation (#3393)

This commit is contained in:
Viktor Liu 2025-02-28 11:08:35 +01:00 committed by GitHub
parent cccc615783
commit fa748a7ec2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 862 additions and 569 deletions

View File

@ -10,17 +10,18 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes)
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
if err != nil {
return nil, err
}

View File

@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes)
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
}
if errUsp != nil {

View File

@ -22,17 +22,17 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if m.forwarder != nil {

View File

@ -31,17 +31,17 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if m.forwarder != nil {

View File

@ -1,20 +1,26 @@
// common.go
package conntrack
import (
"fmt"
"net"
"sync"
"net/netip"
"sync/atomic"
"time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
SourceIP net.IP
DestIP net.IP
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
lastSeen atomic.Int64
}
// these small methods will be inlined by the compiler
@ -35,92 +41,27 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout
}
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcIP netip.Addr
DstIP netip.Addr
SrcPort uint16
DstPort uint16
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}

View File

@ -1,50 +1,23 @@
package conntrack
import (
"context"
"net"
"testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
// Generate different IPs
@ -69,7 +42,7 @@ func BenchmarkMemoryPressure(b *testing.B) {
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
// Generate different IPs

View File

@ -1,13 +1,17 @@
package conntrack
import (
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@ -19,18 +23,19 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
SrcIP netip.Addr
DstIP netip.Addr
Sequence uint16
ID uint16
}
func (i *ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.Sequence, i.ID)
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
Sequence uint16
ID uint16
}
// ICMPTracker manages ICMP connection states
@ -41,11 +46,11 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
flowLogger nftypes.FlowLogger
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
@ -56,41 +61,65 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
key := makeICMPKey(srcIP, dstIP, id, seq)
t.mutex.Lock()
t.mutex.RLock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
t.mutex.RUnlock()
conn = &ICMPConnTrack{
if exists {
conn.UpdateLastSeen()
return key, true
}
return key, false
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, seq, nftypes.Egress)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
t.track(srcIP, dstIP, id, seq, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists {
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
},
ID: id,
Sequence: seq,
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
}
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
conn.UpdateLastSeen()
t.logger.Trace("New %s ICMP connection %s", conn.Direction, key)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
@ -105,18 +134,13 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
return true
}
func (t *ICMPTracker) cleanupRoutine() {
@ -129,17 +153,17 @@ func (t *ICMPTracker) cleanupRoutine() {
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
@ -150,19 +174,29 @@ func (t *ICMPTracker) Close() {
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: key.SrcIP,
DestIP: key.DstIP,
// TODO: add icmp code/type,
})
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ICMPConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}

View File

@ -7,7 +7,7 @@ import (
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) {
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")

View File

@ -8,7 +8,10 @@ import (
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@ -39,6 +42,35 @@ const (
// TCPState represents the state of a TCP connection
type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const (
TCPStateNew TCPState = iota
TCPStateSynSent
@ -53,19 +85,12 @@ const (
TCPStateClosed
)
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
}
@ -79,6 +104,16 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
@ -87,68 +122,94 @@ type TCPTracker struct {
cleanupTicker *time.Ticker
done chan struct{}
timeout time.Duration
ipPool *PreallocatedIPs
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.Lock()
t.mutex.RLock()
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
t.mutex.RUnlock()
conn = &TCPConnTrack{
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.UpdateLastSeen()
conn.Unlock()
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.UpdateLastSeen()
conn.established.Store(false)
t.connections[key] = conn
conn.tombstone.Store(false)
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
}
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.UpdateLastSeen()
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
@ -159,21 +220,25 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
// Handle RST packets
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
if conn.IsTombstone() {
return true
}
conn.Unlock()
return false
}
conn.Lock()
t.updateState(conn, flags, false)
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
return true
}
conn.Lock()
t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
@ -183,18 +248,15 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
return
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
state := conn.State
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
switch conn.State {
switch state {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
@ -241,6 +303,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
case TCPStateClosing:
@ -248,8 +313,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
case TCPStateCloseWait:
@ -260,17 +325,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, key, conn)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
}
}
@ -331,6 +391,12 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
@ -341,14 +407,16 @@ func (t *TCPTracker) cleanup() {
timeout = TCPHandshakeTimeout
}
lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
if conn.timeoutExceeded(timeout) {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
}
@ -360,10 +428,6 @@ func (t *TCPTracker) Close() {
// Clean up all remaining IPs
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
@ -381,3 +445,16 @@ func isValidFlagCombination(flags uint8) bool {
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
})
}

View File

@ -9,7 +9,7 @@ import (
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tt.test(t)
})
}
@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) {
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP,
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) {
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) {
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections

View File

@ -5,7 +5,10 @@ import (
"sync"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@ -28,11 +31,11 @@ type UDPTracker struct {
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
flowLogger nftypes.FlowLogger
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
@ -43,7 +46,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine()
@ -52,32 +55,57 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
}
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.Lock()
t.mutex.RLock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
t.mutex.RUnlock()
conn = &UDPConnTrack{
if exists {
conn.UpdateLastSeen()
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
}
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
conn.UpdateLastSeen()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
@ -88,18 +116,13 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
return true
}
// cleanupRoutine periodically removes stale connections
@ -120,11 +143,10 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
t.logger.Trace("Removed UDP connection %s (timeout)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
@ -135,10 +157,6 @@ func (t *UDPTracker) Close() {
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
@ -150,14 +168,23 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
if !exists {
return nil, false
}
return conn, true
return conn, exists
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
})
}

View File

@ -2,6 +2,7 @@ package conntrack
import (
"net"
"net/netip"
"testing"
"time"
@ -29,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger)
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
@ -40,29 +41,29 @@ func TestNewUDPTracker(t *testing.T) {
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Equal(dstIP))
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger)
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
@ -160,8 +161,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
logger: logger,
flowLogger: flowLogger,
}
// Start cleanup routine
@ -211,7 +212,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
@ -224,7 +225,7 @@ func BenchmarkUDPTracker(b *testing.B) {
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")

View File

@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@ -29,6 +30,7 @@ const (
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
@ -38,7 +40,7 @@ type Forwarder struct {
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger),
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,

View File

@ -3,14 +3,21 @@ package forwarder
import (
"context"
"net"
"net/netip"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
@ -20,6 +27,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
@ -27,6 +36,8 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id)
}()
dstIP := f.determineDialAddr(id.LocalAddress)
@ -101,9 +112,25 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return true
}
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
return true
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 1,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}

View File

@ -5,18 +5,25 @@ import (
"fmt"
"io"
"net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
@ -46,10 +53,10 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
f.logger.Trace("forwarder: established TCP connection %v", id)
go f.proxyTCP(id, inConn, outConn, ep)
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
@ -58,6 +65,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
}()
// Create context for managing the proxy goroutines
@ -88,3 +97,18 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
return
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 6,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}

View File

@ -5,10 +5,12 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@ -16,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@ -28,11 +31,13 @@ type udpPacketConn struct {
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
flowID uuid.UUID
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
@ -44,10 +49,11 @@ type idleConn struct {
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
@ -83,6 +89,21 @@ func (f *udpForwarder) Stop() {
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
@ -119,6 +140,8 @@ func (f *udpForwarder) cleanup() {
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
}
}
}
@ -141,10 +164,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
return
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message
return
}
@ -157,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
@ -168,6 +196,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn: outConn,
cancel: connCancel,
ep: ep,
flowID: flowID,
}
pConn.updateLastSeen()
@ -182,6 +211,8 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
f.udpForwarder.conns[id] = pConn
@ -206,6 +237,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
}()
errChan := make(chan error, 2)
@ -231,6 +264,21 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
}
}
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17, // UDP protocol number
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}

View File

@ -1,4 +1,4 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking
// Package log provides a high-performance, non-blocking logger for userspace networking
package log
import (
@ -13,13 +13,12 @@ import (
)
const (
maxBatchSize = 1024 * 16 // 16KB max batch size
maxMessageSize = 1024 * 2 // 2KB per message
bufferSize = 1024 * 256 // 256KB ring buffer
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
)
// Level represents log severity
type Level uint32
const (
@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC",
}
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
buffer *ringBuffer
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() interface{} {
// Pre-allocate buffer for message formatting
New: func() any {
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
return l
}
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
*buf = append(*buf, '\n')
}
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
func (l *Logger) Warn(format string, args ...interface{}) {
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
func (l *Logger) Info(format string, args ...interface{}) {
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
func (l *Logger) Debug(format string, args ...interface{}) {
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
func (l *Logger) Trace(format string, args ...interface{}) {
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
// worker periodically flushes the buffer
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05.000-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize)
buffer := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
l.handleShutdown(&buffer)
return
case <-ticker.C:
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
_, _ = l.output.Write(buf[:n])
l.flushBuffer(&buffer)
case msg := <-l.msgChannel:
l.processMessage(msg, &buffer)
l.processBatch(&buffer)
}
}
}

View File

@ -0,0 +1,121 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@ -1,85 +0,0 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@ -96,6 +97,7 @@ type Manager struct {
tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder
logger *nblog.Logger
flowLogger nftypes.FlowLogger
}
// decoder for packages
@ -112,16 +114,16 @@ type decoder struct {
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
return create(iface, nil, disableServerRoutes)
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
if err != nil {
return nil, err
}
@ -148,7 +150,7 @@ func parseCreateEnv() (bool, bool) {
return disableConntrack, enableLocalForwarding
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{
@ -174,6 +176,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
}
@ -185,9 +188,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
}
// netstack needs the forwarder for local traffic
@ -304,7 +307,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil {
m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err)
@ -533,14 +536,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
// Track all protocols if stateful mode is enabled
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
}
m.trackOutbound(d, srcIP, dstIP)
}
// Process UDP hooks even if stateful mode is disabled
@ -562,17 +558,6 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
}
}
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
@ -596,13 +581,34 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
m.udpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq)
}
}
}
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq)
}
}
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
@ -618,17 +624,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
)
}
}
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool {
@ -675,6 +670,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
return m.handleNetstackLocalTraffic(packetData)
}
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP)
return false
}

View File

@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})

View File

@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false)
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
require.NotNil(t, manager)
@ -302,7 +302,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false)
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)

View File

@ -1,8 +1,10 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
@ -18,9 +20,11 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -116,7 +120,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -187,7 +191,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@ -236,7 +240,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -279,7 +283,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false)
m, err := Create(ifaceMock, false, flowLogger)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -347,7 +351,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false)
manager, err := Create(iface, false, flowLogger)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@ -393,7 +397,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@ -401,7 +405,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
@ -479,7 +483,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false)
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
time.Sleep(time.Second)
@ -506,7 +510,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false)
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@ -515,7 +519,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
@ -534,8 +538,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}()
// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
@ -543,8 +547,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: layers.IPProtocolUDP,
}
outboundUDP := &layers.UDP{
@ -573,11 +577,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@ -585,8 +589,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: dstIP, // Original destination is now source
DstIP: srcIP, // Original source is now destination
SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP.AsSlice(), // Original source is now destination
Protocol: layers.IPProtocolUDP,
}
inboundUDP := &layers.UDP{
@ -641,7 +645,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")

View File

@ -1,6 +1,7 @@
package acl
import (
"context"
"net"
"testing"
@ -10,9 +11,12 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@ -52,7 +56,7 @@ func TestDefaultManager(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return
@ -346,7 +350,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return

View File

@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
@ -29,6 +30,8 @@ import (
"github.com/netbirdio/netbird/formatter"
)
var flowLogger = netflow.NewManager(context.Background()).GetLogger()
type mocWGIface struct {
filter device.PacketFilter
}
@ -916,7 +919,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface, false)
pf, err := uspfilter.Create(wgIface, false, flowLogger)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@ -35,7 +35,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/netflow/types"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
@ -191,7 +191,7 @@ type Engine struct {
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager types.FlowManager
flowManager nftypes.FlowManager
}
// Peer is an instance of the Connection Peer
@ -454,7 +454,7 @@ func (e *Engine) createFirewall() error {
}
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
@ -721,11 +721,11 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
return e.flowManager.Update(flowConfig)
}
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*types.FlowConfig, error) {
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
if config.GetInterval() == nil {
return nil, errors.New("flow interval is nil")
}
return &types.FlowConfig{
return &nftypes.FlowConfig{
Enabled: config.GetEnabled(),
URL: config.GetUrl(),
TokenPayload: config.GetTokenPayload(),

View File

@ -7,17 +7,52 @@ import (
"github.com/google/uuid"
)
type Protocol uint8
const (
ProtocolUnknown = 0
ICMP = 1
TCP = 6
UDP = 17
)
func (p Protocol) String() string {
switch p {
case 1:
return "ICMP"
case 6:
return "TCP"
case 17:
return "UDP"
default:
return "unknown"
}
}
type Type int
const (
TypeStart = iota
TypeUnknown = iota
TypeStart
TypeEnd
)
type Direction int
func (d Direction) String() string {
switch d {
case Ingress:
return "ingress"
case Egress:
return "egress"
default:
return "unknown"
}
}
const (
Ingress = iota
DirectionUnknown = iota
Ingress
Egress
)