From ad9f044aadadf6a77facd418373e9aad32095443 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 23 Dec 2024 18:22:17 +0100 Subject: [PATCH] [client] Add stateful userspace firewall and remove egress filters (#3093) - Add stateful firewall functionality for UDP/TCP/ICMP in userspace firewalll - Removes all egress drop rules/filters, still needs refactoring so we don't add output rules to any chains/filters. - on Linux, if the OUTPUT policy is DROP then we don't do anything about it (no extra allow rules). This is up to the user, if they don't want anything leaving their machine they'll have to manage these rules explicitly. --- client/firewall/iptables/acl_linux.go | 6 - client/firewall/iptables/manager_linux.go | 14 +- client/firewall/nftables/acl_linux.go | 53 - client/firewall/uspfilter/allow_netbird.go | 20 +- .../uspfilter/allow_netbird_windows.go | 16 + client/firewall/uspfilter/conntrack/common.go | 138 +++ .../uspfilter/conntrack/common_test.go | 115 ++ client/firewall/uspfilter/conntrack/icmp.go | 170 +++ .../firewall/uspfilter/conntrack/icmp_test.go | 39 + client/firewall/uspfilter/conntrack/tcp.go | 376 +++++++ .../firewall/uspfilter/conntrack/tcp_test.go | 311 ++++++ client/firewall/uspfilter/conntrack/udp.go | 158 +++ .../firewall/uspfilter/conntrack/udp_test.go | 243 +++++ client/firewall/uspfilter/uspfilter.go | 258 ++++- .../uspfilter/uspfilter_bench_test.go | 998 ++++++++++++++++++ client/firewall/uspfilter/uspfilter_test.go | 309 +++++- 16 files changed, 3104 insertions(+), 120 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/common.go create mode 100644 client/firewall/uspfilter/conntrack/common_test.go create mode 100644 client/firewall/uspfilter/conntrack/icmp.go create mode 100644 client/firewall/uspfilter/conntrack/icmp_test.go create mode 100644 client/firewall/uspfilter/conntrack/tcp.go create mode 100644 client/firewall/uspfilter/conntrack/tcp_test.go create mode 100644 client/firewall/uspfilter/conntrack/udp.go create mode 100644 client/firewall/uspfilter/conntrack/udp_test.go create mode 100644 client/firewall/uspfilter/uspfilter_bench_test.go diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 1c0527ebc..d774f4538 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error { // The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. func (m *aclManager) seedInitialEntries() { - established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index adb8f20ef..0e1e5836f 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error { "", ) if err != nil { - return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + return fmt.Errorf("allow netbird interface traffic: %w", err) } - _, err = m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - firewall.RuleDirectionOUT, - firewall.ActionAccept, - "", - "", - ) - return err + return nil } // Flush doesn't need to be implemented for this manager diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index abe890fb9..852cfec8d 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "net" - "net/netip" "strconv" "strings" "time" @@ -28,7 +27,6 @@ const ( // filter chains contains the rules that jump to the rules chains chainNameInputFilter = "netbird-acl-input-filter" - chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" chainNamePrerouting = "netbird-rt-prerouting" @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) { return err } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addFwdAllow(chain, expr.MetaKeyOIFNAME) - m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules - m.addDropExpressions(chain, expr.MetaKeyOIFNAME) - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err) - return err - } - // netbird-acl-forward-filter chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - dstOp := expr.CmpOpNeq - expressions := []expr.Any{ - &expr.Meta{Key: iifname, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cefc81a3c..cc0792255 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,7 +2,10 @@ package uspfilter -import "github.com/netbirdio/netbird/client/internal/statemanager" +import ( + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/internal/statemanager" +) // Reset firewall to the default state func (m *Manager) Reset(stateManager *statemanager.Manager) error { @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index d3732301e..0d55d6268 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go new file mode 100644 index 000000000..a4b1971bf --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common.go @@ -0,0 +1,138 @@ +// common.go +package conntrack + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// BaseConnTrack provides common fields and locking for all connection types +type BaseConnTrack struct { + sync.RWMutex + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access + established atomic.Bool +} + +// these small methods will be inlined by the compiler + +// UpdateLastSeen safely updates the last seen timestamp +func (b *BaseConnTrack) UpdateLastSeen() { + b.lastSeen.Store(time.Now().UnixNano()) +} + +// IsEstablished safely checks if connection is established +func (b *BaseConnTrack) IsEstablished() bool { + return b.established.Load() +} + +// SetEstablished safely sets the established state +func (b *BaseConnTrack) SetEstablished(state bool) { + b.established.Store(state) +} + +// GetLastSeen safely gets the last seen timestamp +func (b *BaseConnTrack) GetLastSeen() time.Time { + return time.Unix(0, b.lastSeen.Load()) +} + +// timeoutExceeded checks if the connection has exceeded the given timeout +func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { + lastSeen := time.Unix(0, b.lastSeen.Load()) + return time.Since(lastSeen) > timeout +} + +// IPAddr is a fixed-size IP address to avoid allocations +type IPAddr [16]byte + +// MakeIPAddr creates an IPAddr from net.IP +func MakeIPAddr(ip net.IP) (addr IPAddr) { + // Optimization: check for v4 first as it's more common + if ip4 := ip.To4(); ip4 != nil { + copy(addr[12:], ip4) + } else { + copy(addr[:], ip.To16()) + } + return addr +} + +// ConnKey uniquely identifies a connection +type ConnKey struct { + SrcIP IPAddr + DstIP IPAddr + SrcPort uint16 + DstPort uint16 +} + +// makeConnKey creates a connection key +func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + return ConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + } +} + +// ValidateIPs checks if IPs match without allocation +func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { + if ip4 := pktIP.To4(); ip4 != nil { + // Compare IPv4 addresses (last 4 bytes) + for i := 0; i < 4; i++ { + if connIP[12+i] != ip4[i] { + return false + } + } + return true + } + // Compare full IPv6 addresses + ip6 := pktIP.To16() + for i := 0; i < 16; i++ { + if connIP[i] != ip6[i] { + return false + } + } + return true +} + +// PreallocatedIPs is a pool of IP byte slices to reduce allocations +type PreallocatedIPs struct { + sync.Pool +} + +// NewPreallocatedIPs creates a new IP pool +func NewPreallocatedIPs() *PreallocatedIPs { + return &PreallocatedIPs{ + Pool: sync.Pool{ + New: func() interface{} { + ip := make(net.IP, 16) + return &ip + }, + }, + } +} + +// Get retrieves an IP from the pool +func (p *PreallocatedIPs) Get() net.IP { + return *p.Pool.Get().(*net.IP) +} + +// Put returns an IP to the pool +func (p *PreallocatedIPs) Put(ip net.IP) { + p.Pool.Put(&ip) +} + +// copyIP copies an IP address efficiently +func copyIP(dst, src net.IP) { + if len(src) == 16 { + copy(dst, src) + } else { + // Handle IPv4 + copy(dst[12:], src.To4()) + } +} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go new file mode 100644 index 000000000..72d006def --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -0,0 +1,115 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkIPOperations(b *testing.B) { + b.Run("MakeIPAddr", func(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = MakeIPAddr(ip) + } + }) + + b.Run("ValidateIPs", func(b *testing.B) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.1") + addr := MakeIPAddr(ip1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ValidateIPs(addr, ip2) + } + }) + + b.Run("IPPool", func(b *testing.B) { + pool := NewPreallocatedIPs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := pool.Get() + pool.Put(ip) + } + }) + +} +func BenchmarkAtomicOperations(b *testing.B) { + conn := &BaseConnTrack{} + b.Run("UpdateLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.UpdateLastSeen() + } + }) + + b.Run("IsEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.IsEstablished() + } + }) + + b.Run("SetEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.SetEstablished(i%2 == 0) + } + }) + + b.Run("GetLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.GetLastSeen() + } + }) +} + +// Memory pressure tests +func BenchmarkMemoryPressure(b *testing.B) { + b.Run("TCPHighLoad", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + } + } + }) + + b.Run("UDPHighLoad", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + } + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go new file mode 100644 index 000000000..e0a971678 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -0,0 +1,170 @@ +package conntrack + +import ( + "net" + "sync" + "time" + + "github.com/google/gopacket/layers" +) + +const ( + // DefaultICMPTimeout is the default timeout for ICMP connections + DefaultICMPTimeout = 30 * time.Second + // ICMPCleanupInterval is how often we check for stale ICMP connections + ICMPCleanupInterval = 15 * time.Second +) + +// ICMPConnKey uniquely identifies an ICMP connection +type ICMPConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + Sequence uint16 // ICMP sequence number + ID uint16 // ICMP identifier +} + +// ICMPConnTrack represents an ICMP connection state +type ICMPConnTrack struct { + BaseConnTrack + Sequence uint16 + ID uint16 +} + +// ICMPTracker manages ICMP connection states +type ICMPTracker struct { + connections map[ICMPConnKey]*ICMPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewICMPTracker creates a new ICMP connection tracker +func NewICMPTracker(timeout time.Duration) *ICMPTracker { + if timeout == 0 { + timeout = DefaultICMPTimeout + } + + tracker := &ICMPTracker{ + connections: make(map[ICMPConnKey]*ICMPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(ICMPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound ICMP Echo Request +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + key := makeICMPKey(srcIP, dstIP, id, seq) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + }, + ID: id, + Sequence: seq, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { + switch icmpType { + case uint8(layers.ICMPv4TypeDestinationUnreachable), + uint8(layers.ICMPv4TypeTimeExceeded): + return true + case uint8(layers.ICMPv4TypeEchoReply): + // continue processing + default: + return false + } + + key := makeICMPKey(dstIP, srcIP, id, seq) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.ID == id && + conn.Sequence == seq +} + +func (t *ICMPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} +func (t *ICMPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *ICMPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// makeICMPKey creates an ICMP connection key +func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + return ICMPConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + ID: id, + Sequence: seq, + } +} diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go new file mode 100644 index 000000000..21176e719 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -0,0 +1,39 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkICMPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go new file mode 100644 index 000000000..e8d20f41c --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -0,0 +1,376 @@ +package conntrack + +// TODO: Send RST packets for invalid/timed-out connections + +import ( + "net" + "sync" + "time" +) + +const ( + // MSL (Maximum Segment Lifetime) is typically 2 minutes + MSL = 2 * time.Minute + // TimeWaitTimeout (TIME-WAIT) should last 2*MSL + TimeWaitTimeout = 2 * MSL +) + +const ( + TCPSyn uint8 = 0x02 + TCPAck uint8 = 0x10 + TCPFin uint8 = 0x01 + TCPRst uint8 = 0x04 + TCPPush uint8 = 0x08 + TCPUrg uint8 = 0x20 +) + +const ( + // DefaultTCPTimeout is the default timeout for established TCP connections + DefaultTCPTimeout = 3 * time.Hour + // TCPHandshakeTimeout is timeout for TCP handshake completion + TCPHandshakeTimeout = 60 * time.Second + // TCPCleanupInterval is how often we check for stale connections + TCPCleanupInterval = 5 * time.Minute +) + +// TCPState represents the state of a TCP connection +type TCPState int + +const ( + TCPStateNew TCPState = iota + TCPStateSynSent + TCPStateSynReceived + TCPStateEstablished + TCPStateFinWait1 + TCPStateFinWait2 + TCPStateClosing + TCPStateTimeWait + TCPStateCloseWait + TCPStateLastAck + TCPStateClosed +) + +// TCPConnKey uniquely identifies a TCP connection +type TCPConnKey struct { + SrcIP [16]byte + DstIP [16]byte + SrcPort uint16 + DstPort uint16 +} + +// TCPConnTrack represents a TCP connection state +type TCPConnTrack struct { + BaseConnTrack + State TCPState +} + +// TCPTracker manages TCP connection states +type TCPTracker struct { + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + done chan struct{} + timeout time.Duration + ipPool *PreallocatedIPs +} + +// NewTCPTracker creates a new TCP connection tracker +func NewTCPTracker(timeout time.Duration) *TCPTracker { + tracker := &TCPTracker{ + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + done: make(chan struct{}), + timeout: timeout, + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound processes an outbound TCP packet and updates connection state +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + // Create key before lock + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + State: TCPStateNew, + } + conn.lastSeen.Store(now) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + + // Lock individual connection for state update + conn.Lock() + t.updateState(conn, flags, true) + conn.Unlock() + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound TCP packet matches a tracked connection +func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + + // Handle new SYN packets + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.Lock() + if _, exists := t.connections[key]; !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, dstIP) + copyIP(dstIPCopy, srcIP) + + conn := &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: dstPort, + DestPort: srcPort, + }, + State: TCPStateSynReceived, + } + conn.lastSeen.Store(time.Now().UnixNano()) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + return true + } + + // Look up existing connection + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + // Handle RST packets + if flags&TCPRst != 0 { + conn.Lock() + isEstablished := conn.IsEstablished() + if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.State = TCPStateClosed + conn.SetEstablished(false) + conn.Unlock() + return true + } + conn.Unlock() + return false + } + + // Update state + conn.Lock() + t.updateState(conn, flags, false) + conn.UpdateLastSeen() + isEstablished := conn.IsEstablished() + isValidState := t.isValidStateForFlags(conn.State, flags) + conn.Unlock() + + return isEstablished || isValidState +} + +// updateState updates the TCP connection state based on flags +func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { + // Handle RST flag specially - it always causes transition to closed + if flags&TCPRst != 0 { + conn.State = TCPStateClosed + conn.SetEstablished(false) + return + } + + switch conn.State { + case TCPStateNew: + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + conn.State = TCPStateSynSent + } + + case TCPStateSynSent: + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if isOutbound { + conn.State = TCPStateSynReceived + } else { + // Simultaneous open + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + } + + case TCPStateSynReceived: + if flags&TCPAck != 0 && flags&TCPSyn == 0 { + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + + case TCPStateEstablished: + if flags&TCPFin != 0 { + if isOutbound { + conn.State = TCPStateFinWait1 + } else { + conn.State = TCPStateCloseWait + } + conn.SetEstablished(false) + } + + case TCPStateFinWait1: + switch { + case flags&TCPFin != 0 && flags&TCPAck != 0: + // Simultaneous close - both sides sent FIN + conn.State = TCPStateClosing + case flags&TCPFin != 0: + conn.State = TCPStateFinWait2 + case flags&TCPAck != 0: + conn.State = TCPStateFinWait2 + } + + case TCPStateFinWait2: + if flags&TCPFin != 0 { + conn.State = TCPStateTimeWait + } + + case TCPStateClosing: + if flags&TCPAck != 0 { + conn.State = TCPStateTimeWait + // Keep established = false from previous state + } + + case TCPStateCloseWait: + if flags&TCPFin != 0 { + conn.State = TCPStateLastAck + } + + case TCPStateLastAck: + if flags&TCPAck != 0 { + conn.State = TCPStateClosed + } + + case TCPStateTimeWait: + // Stay in TIME-WAIT for 2MSL before transitioning to closed + // This is handled by the cleanup routine + } +} + +// isValidStateForFlags checks if the TCP flags are valid for the current connection state +func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { + if !isValidFlagCombination(flags) { + return false + } + + switch state { + case TCPStateNew: + return flags&TCPSyn != 0 && flags&TCPAck == 0 + case TCPStateSynSent: + return flags&TCPSyn != 0 && flags&TCPAck != 0 + case TCPStateSynReceived: + return flags&TCPAck != 0 + case TCPStateEstablished: + if flags&TCPRst != 0 { + return true + } + return flags&TCPAck != 0 + case TCPStateFinWait1: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateFinWait2: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateClosing: + // In CLOSING state, we should accept the final ACK + return flags&TCPAck != 0 + case TCPStateTimeWait: + // In TIME_WAIT, we might see retransmissions + return flags&TCPAck != 0 + case TCPStateCloseWait: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateLastAck: + return flags&TCPAck != 0 + } + return false +} + +func (t *TCPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *TCPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + var timeout time.Duration + switch { + case conn.State == TCPStateTimeWait: + timeout = TimeWaitTimeout + case conn.IsEstablished(): + timeout = t.timeout + default: + timeout = TCPHandshakeTimeout + } + + lastSeen := conn.GetLastSeen() + if time.Since(lastSeen) > timeout { + // Return IPs to pool + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *TCPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + // Clean up all remaining IPs + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +func isValidFlagCombination(flags uint8) bool { + // Invalid: SYN+FIN + if flags&TCPSyn != 0 && flags&TCPFin != 0 { + return false + } + + // Invalid: RST with SYN or FIN + if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) { + return false + } + + return true +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go new file mode 100644 index 000000000..3933c8889 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -0,0 +1,311 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTCPStateMachine(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Security Tests", func(t *testing.T) { + tests := []struct { + name string + flags uint8 + wantDrop bool + desc string + }{ + { + name: "Block unsolicited SYN-ACK", + flags: TCPSyn | TCPAck, + wantDrop: true, + desc: "Should block SYN-ACK without prior SYN", + }, + { + name: "Block invalid SYN-FIN", + flags: TCPSyn | TCPFin, + wantDrop: true, + desc: "Should block invalid SYN-FIN combination", + }, + { + name: "Block unsolicited RST", + flags: TCPRst, + wantDrop: true, + desc: "Should block RST without connection", + }, + { + name: "Block unsolicited ACK", + flags: TCPAck, + wantDrop: true, + desc: "Should block ACK without connection", + }, + { + name: "Block data without connection", + flags: TCPAck | TCPPush, + wantDrop: true, + desc: "Should block data without established connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + require.Equal(t, !tt.wantDrop, isValid, tt.desc) + }) + } + }) + + t.Run("Connection Flow Tests", func(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + desc string + }{ + { + name: "Normal Handshake", + test: func(t *testing.T) { + t.Helper() + + // Send initial SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + // Receive SYN-ACK + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + // Send ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + + // Test data transfer + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + require.True(t, valid, "Data should be allowed after handshake") + }, + }, + { + name: "Normal Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "ACK for FIN should be allowed") + + // Receive FIN from other side + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "FIN should be allowed") + + // Send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + }, + { + name: "RST During Connection", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Receive RST + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + require.True(t, valid, "RST should be allowed for established connection") + + // Verify connection is closed + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + t.Helper() + + require.False(t, valid, "Data should be blocked after RST") + }, + }, + { + name: "Simultaneous Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Both sides send FIN+ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "Simultaneous FIN should be allowed") + + // Both sides send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "Final ACKs should be allowed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + + tracker = NewTCPTracker(DefaultTCPTimeout) + tt.test(t) + }) + } + }) +} + +func TestRSTHandling(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + tests := []struct { + name string + setupState func() + sendRST func() + wantValid bool + desc string + }{ + { + name: "RST in established", + setupState: func() { + // Establish connection first + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: true, + desc: "Should accept RST for established connection", + }, + { + name: "RST without connection", + setupState: func() {}, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: false, + desc: "Should reject RST without connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + tt.sendRST() + + // Verify connection state is as expected + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + if tt.wantValid { + require.NotNil(t, conn) + require.Equal(t, TCPStateClosed, conn.State) + require.False(t, conn.IsEstablished()) + } + }) + } +} + +// Helper to establish a TCP connection +func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { + t.Helper() + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) +} + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go new file mode 100644 index 000000000..a969a4e84 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -0,0 +1,158 @@ +package conntrack + +import ( + "net" + "sync" + "time" +) + +const ( + // DefaultUDPTimeout is the default timeout for UDP connections + DefaultUDPTimeout = 30 * time.Second + // UDPCleanupInterval is how often we check for stale connections + UDPCleanupInterval = 15 * time.Second +) + +// UDPConnTrack represents a UDP connection state +type UDPConnTrack struct { + BaseConnTrack +} + +// UDPTracker manages UDP connection states +type UDPTracker struct { + connections map[ConnKey]*UDPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewUDPTracker creates a new UDP connection tracker +func NewUDPTracker(timeout time.Duration) *UDPTracker { + if timeout == 0 { + timeout = DefaultUDPTimeout + } + + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(UDPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound UDP connection +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound packet matches a tracked connection +func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.DestPort == srcPort && + conn.SourcePort == dstPort +} + +// cleanupRoutine periodically removes stale connections +func (t *UDPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *UDPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *UDPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// GetConnection safely retrieves a connection state +func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := t.connections[key] + if !exists { + return nil, false + } + + return conn, true +} + +// Timeout returns the configured timeout duration for the tracker +func (t *UDPTracker) Timeout() time.Duration { + return t.timeout +} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go new file mode 100644 index 000000000..671721890 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -0,0 +1,243 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewUDPTracker(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantTimeout time.Duration + }{ + { + name: "with custom timeout", + timeout: 1 * time.Minute, + wantTimeout: 1 * time.Minute, + }, + { + name: "with zero timeout uses default", + timeout: 0, + wantTimeout: DefaultUDPTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewUDPTracker(tt.timeout) + assert.NotNil(t, tracker) + assert.Equal(t, tt.wantTimeout, tracker.timeout) + assert.NotNil(t, tracker.connections) + assert.NotNil(t, tracker.cleanupTicker) + assert.NotNil(t, tracker.done) + }) + } +} + +func TestUDPTracker_TrackOutbound(t *testing.T) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + // Verify connection was tracked + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := tracker.connections[key] + require.True(t, exists) + assert.True(t, conn.SourceIP.Equal(srcIP)) + assert.True(t, conn.DestIP.Equal(dstIP)) + assert.Equal(t, srcPort, conn.SourcePort) + assert.Equal(t, dstPort, conn.DestPort) + assert.True(t, conn.IsEstablished()) + assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) +} + +func TestUDPTracker_IsValidInbound(t *testing.T) { + tracker := NewUDPTracker(1 * time.Second) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + // Track outbound connection + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + tests := []struct { + name string + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + sleep time.Duration + want bool + }{ + { + name: "valid inbound response", + srcIP: dstIP, // Original destination is now source + dstIP: srcIP, // Original source is now destination + srcPort: dstPort, // Original destination port is now source + dstPort: srcPort, // Original source port is now destination + sleep: 0, + want: true, + }, + { + name: "invalid source IP", + srcIP: net.ParseIP("192.168.1.4"), + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination IP", + srcIP: dstIP, + dstIP: net.ParseIP("192.168.1.4"), + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid source port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: 54321, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: 54321, + sleep: 0, + want: false, + }, + { + name: "expired connection", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 2 * time.Second, // Longer than tracker timeout + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sleep > 0 { + time.Sleep(tt.sleep) + } + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUDPTracker_Cleanup(t *testing.T) { + // Use shorter intervals for testing + timeout := 50 * time.Millisecond + cleanupInterval := 25 * time.Millisecond + + // Create tracker with custom cleanup interval + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(cleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + // Start cleanup routine + go tracker.cleanupRoutine() + + // Add some connections + connections := []struct { + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + }{ + { + srcIP: net.ParseIP("192.168.1.2"), + dstIP: net.ParseIP("192.168.1.3"), + srcPort: 12345, + dstPort: 53, + }, + { + srcIP: net.ParseIP("192.168.1.4"), + dstIP: net.ParseIP("192.168.1.5"), + srcPort: 12346, + dstPort: 53, + }, + } + + for _, conn := range connections { + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + } + + // Verify initial connections + assert.Len(t, tracker.connections, 2) + + // Wait for connection timeout and cleanup interval + time.Sleep(timeout + 2*cleanupInterval) + + tracker.mutex.RLock() + connCount := len(tracker.connections) + tracker.mutex.RUnlock() + + // Verify connections were cleaned up + assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up") + + // Properly close the tracker + tracker.Close() +} + +func BenchmarkUDPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + } + }) +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index fb726395b..24cfd6e96 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "net/netip" + "os" + "strconv" "sync" "github.com/google/gopacket" @@ -12,6 +14,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -19,6 +22,8 @@ import ( const layerTypeAll = 0 +const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) @@ -42,6 +47,11 @@ type Manager struct { nativeFirewall firewall.Manager mutex sync.RWMutex + + stateful bool + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker } // decoder for packages @@ -73,6 +83,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager } func create(iface IFaceMapper) (*Manager, error) { + disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + m := &Manager{ decoders: sync.Pool{ New: func() any { @@ -90,6 +102,16 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, + stateful: !disableConntrack, + } + + // Only initialize trackers if stateful mode is enabled + if disableConntrack { + log.Info("conntrack is disabled") + } else { + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } if err := iface.SetFilter(m); err != nil { @@ -249,16 +271,16 @@ func (m *Manager) Flush() error { return nil } // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte) bool { - return m.dropFilter(packetData, m.outgoingRules, false) + return m.processOutgoingHooks(packetData) } // DropIncoming filter incoming packets func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData, m.incomingRules, true) + return m.dropFilter(packetData, m.incomingRules) } -// dropFilter implements same logic for booth direction of the traffic -func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool { +// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP +func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -266,61 +288,213 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco defer m.decoders.Put(d) if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - log.Tracef("couldn't decode layer, err: %s", err) - return true + return false } if len(d.decoded) < 2 { - log.Tracef("not enough levels in network packet") + return false + } + + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { + return false + } + + // Always process UDP hooks + if d.decoded[1] == layers.LayerTypeUDP { + // Track UDP state only if enabled + if m.stateful { + m.trackUDPOutbound(d, srcIP, dstIP) + } + return m.checkUDPHooks(d, dstIP, packetData) + } + + // Track other protocols only if stateful mode is enabled + if m.stateful { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.trackTCPOutbound(d, srcIP, dstIP) + case layers.LayerTypeICMPv4: + m.trackICMPOutbound(d, srcIP, dstIP) + } + } + + return false +} + +func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + return d.ip4.SrcIP, d.ip4.DstIP + case layers.LayerTypeIPv6: + return d.ip6.SrcIP, d.ip6.DstIP + default: + return nil, nil + } +} + +func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) { + flags := getTCPFlags(&d.tcp) + m.tcpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + flags, + ) +} + +func getTCPFlags(tcp *layers.TCP) uint8 { + var flags uint8 + if tcp.SYN { + flags |= conntrack.TCPSyn + } + if tcp.ACK { + flags |= conntrack.TCPAck + } + if tcp.FIN { + flags |= conntrack.TCPFin + } + if tcp.RST { + flags |= conntrack.TCPRst + } + if tcp.PSH { + flags |= conntrack.TCPPush + } + if tcp.URG { + flags |= conntrack.TCPUrg + } + return flags +} + +func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) +} + +func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { + if rules, exists := m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) + } + } + } + } + return false +} + +func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + ) + } +} + +// dropFilter implements filtering logic for incoming packets +func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if !m.isValidPacket(d, packetData) { return true } - ipLayer := d.decoded[0] - - switch ipLayer { - case layers.LayerTypeIPv4: - if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) { - return false - } - case layers.LayerTypeIPv6: - if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) { - return false - } - default: + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { log.Errorf("unknown layer: %v", d.decoded[0]) return true } - var ip net.IP - switch ipLayer { - case layers.LayerTypeIPv4: - if isIncomingPacket { - ip = d.ip4.SrcIP - } else { - ip = d.ip4.DstIP - } - case layers.LayerTypeIPv6: - if isIncomingPacket { - ip = d.ip6.SrcIP - } else { - ip = d.ip6.DstIP - } + if !m.isWireguardTraffic(srcIP, dstIP) { + return false } - filter, ok := validateRule(ip, packetData, rules[ip.String()], d) - if ok { - return filter + // Check connection state only if enabled + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + return false } - filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) - if ok { - return filter + + return m.applyRules(srcIP, packetData, rules, d) +} + +func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + log.Tracef("couldn't decode layer, err: %s", err) + return false } - filter, ok = validateRule(ip, packetData, rules["::"], d) - if ok { + + if len(d.decoded) < 2 { + log.Tracef("not enough levels in network packet") + return false + } + return true +} + +func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { + return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) +} + +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return m.tcpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + getTCPFlags(&d.tcp), + ) + + case layers.LayerTypeUDP: + return m.udpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + case layers.LayerTypeICMPv4: + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + d.icmp4.TypeCode.Type(), + ) + + // TODO: ICMPv6 + } + + return false +} + +func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } - // default policy is DROP ALL + if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { + return filter + } + + if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { + return filter + } + + // Default policy: DROP ALL return true } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go new file mode 100644 index 000000000..3c661e71c --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -0,0 +1,998 @@ +package uspfilter + +import ( + "fmt" + "math/rand" + "net" + "os" + "strings" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface/device" +) + +// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range +func generateRandomIPs(n int) []net.IP { + ips := make([]net.IP, n) + seen := make(map[string]bool) + + for i := 0; i < n; { + ip := make(net.IP, 4) + ip[0] = 100 + ip[1] = byte(64 + rand.Intn(63)) // 64-126 + ip[2] = byte(rand.Intn(256)) + ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255 + + key := ip.String() + if !seen[key] { + ips[i] = ip + seen[key] = true + i++ + } + } + return ips +} + +func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + } + + var transportLayer gopacket.SerializableLayer + switch protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = tcp + case layers.IPProtocolUDP: + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = udp + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test")) + require.NoError(b, err) + return buf.Bytes() +} + +// BenchmarkCoreFiltering focuses on the essential performance comparisons between +// stateful and stateless filtering approaches +func BenchmarkCoreFiltering(b *testing.B) { + scenarios := []struct { + name string + stateful bool + setupFunc func(*Manager) + desc string + }{ + { + name: "stateless_single_allow_all", + stateful: false, + setupFunc: func(m *Manager) { + // Single rule allowing all traffic + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + require.NoError(b, err) + }, + desc: "Baseline: Single 'allow all' rule without connection tracking", + }, + { + name: "stateful_no_rules", + stateful: true, + setupFunc: func(m *Manager) { + // No explicit rules - rely purely on connection tracking + }, + desc: "Pure connection tracking without any rules", + }, + { + name: "stateless_explicit_return", + stateful: false, + setupFunc: func(m *Manager) { + // Add explicit rules matching return traffic pattern + for i := 0; i < 1000; i++ { // Simulate realistic ruleset size + ip := generateRandomIPs(1)[0] + _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, + &fw.Port{Values: []int{1024 + i}}, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + require.NoError(b, err) + } + }, + desc: "Explicit rules matching return traffic patterns without state", + }, + { + name: "stateful_with_established", + stateful: true, + setupFunc: func(m *Manager) { + // Add some basic rules but rely on state for established connections + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, + fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + require.NoError(b, err) + }, + desc: "Connection tracking with established connections", + }, + } + + // Test both TCP and UDP + protocols := []struct { + name string + proto layers.IPProtocol + }{ + {"TCP", layers.IPProtocolTCP}, + {"UDP", layers.IPProtocolUDP}, + } + + for _, sc := range scenarios { + for _, proto := range protocols { + b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + } else { + require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m")) + } + + // Create manager and basic setup + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Apply scenario-specific setup + sc.setupFunc(manager) + + // Generate test packets + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + srcPort := uint16(1024 + b.N%60000) + dstPort := uint16(80) + + outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto) + inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto) + + // For stateful scenarios, establish the connection + if sc.stateful { + manager.processOutgoingHooks(outbound) + } + + // Measure inbound packet processing + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } + } +} + +// BenchmarkStateScaling measures how performance scales with connection table size +func BenchmarkStateScaling(b *testing.B) { + connCounts := []int{100, 1000, 10000, 100000} + + for _, count := range connCounts { + b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Pre-populate connection table + srcIPs := generateRandomIPs(count) + dstIPs := generateRandomIPs(count) + for i := 0; i < count; i++ { + outbound := generatePacket(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, layers.IPProtocolTCP) + manager.processOutgoingHooks(outbound) + } + + // Test packet + testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) + testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) + + // First establish our test connection + manager.processOutgoingHooks(testOut) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(testIn, manager.incomingRules) + } + }) + } +} + +// BenchmarkEstablishmentOverhead measures the overhead of connection establishment +func BenchmarkEstablishmentOverhead(b *testing.B) { + scenarios := []struct { + name string + established bool + }{ + {"established", true}, + {"new", false}, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) + inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + + if sc.established { + manager.processOutgoingHooks(outbound) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic +func BenchmarkRoutedNetworkReturn(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + state string // "new", "established", "post_handshake" (TCP only) + setupFunc func(*Manager) + genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario + desc string + }{ + { + name: "allow_non_wg_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Allow non-WG: TCP new connection", + }, + { + name: "allow_non_wg_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with ACK flag for established connection + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Allow non-WG: TCP established connection", + }, + { + name: "allow_non_wg_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP new connection", + }, + { + name: "allow_non_wg_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP established connection", + }, + { + name: "stateful_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Stateful: TCP new connection", + }, + { + name: "stateful_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate established TCP packets (ACK flag) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Stateful: TCP established connection", + }, + { + name: "stateful_tcp_post_handshake", + proto: layers.IPProtocolTCP, + state: "post_handshake", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with PSH+ACK flags for data transfer + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + desc: "Stateful: TCP post-handshake data transfer", + }, + { + name: "stateful_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP new connection", + }, + { + name: "stateful_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP established connection", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + // Setup scenario + sc.setupFunc(manager) + + // Use IPs outside WG range for routed network simulation + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("8.8.8.8") + outbound, inbound := sc.genPackets(srcIP, dstIP) + + // For stateful cases and established connections + if !strings.Contains(sc.name, "allow_non_wg") || + (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { + manager.processOutgoingHooks(outbound) + + // For TCP post-handshake, simulate full handshake + if sc.state == "post_handshake" { + // SYN + syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + // ACK + ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +var scenarios = []struct { + name string + stateful bool // Whether conntrack is enabled + rules bool // Whether to add return traffic rules + routed bool // Whether to test routed network traffic + connCount int // Number of concurrent connections + desc string +}{ + { + name: "stateless_with_rules_100conns", + stateful: false, + rules: true, + routed: false, + connCount: 100, + desc: "Pure stateless with return traffic rules, 100 conns", + }, + { + name: "stateless_with_rules_1000conns", + stateful: false, + rules: true, + routed: false, + connCount: 1000, + desc: "Pure stateless with return traffic rules, 1000 conns", + }, + { + name: "stateful_no_rules_100conns", + stateful: true, + rules: false, + routed: false, + connCount: 100, + desc: "Pure stateful tracking without rules, 100 conns", + }, + { + name: "stateful_no_rules_1000conns", + stateful: true, + rules: false, + routed: false, + connCount: 1000, + desc: "Pure stateful tracking without rules, 1000 conns", + }, + { + name: "stateful_with_rules_100conns", + stateful: true, + rules: true, + routed: false, + connCount: 100, + desc: "Combined stateful + rules (current implementation), 100 conns", + }, + { + name: "stateful_with_rules_1000conns", + stateful: true, + rules: true, + routed: false, + connCount: 1000, + desc: "Combined stateful + rules (current implementation), 1000 conns", + }, + { + name: "routed_network_100conns", + stateful: true, + rules: false, + routed: true, + connCount: 100, + desc: "Routed network traffic (non-WG), 100 conns", + }, + { + name: "routed_network_1000conns", + stateful: true, + rules: false, + routed: true, + connCount: 1000, + desc: "Routed network traffic (non-WG), 1000 conns", + }, +} + +// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns +func BenchmarkLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + // Initial SYN + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + // ACK + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Prepare test packets simulating bidirectional traffic + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + // Server -> Client (inbound) + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + // Client -> Server (outbound) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + connIdx := i % sc.connCount + + // Simulate bidirectional traffic + // First outbound data + manager.processOutgoingHooks(outPackets[connIdx]) + // Then inbound response - this is what we're actually measuring + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + } +} + +// BenchmarkShortLivedConnections tests performance with many short-lived connections +func BenchmarkShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create packet patterns for a complete HTTP-like short connection: + // 1. Initial handshake (SYN, SYN-ACK, ACK) + // 2. HTTP Request (PSH+ACK from client) + // 3. HTTP Response (PSH+ACK from server) + // 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK) + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + // Generate all possible connection patterns + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + // Handshake + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + + // Data transfer + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + + // Connection teardown + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Each iteration creates a new short-lived connection + connIdx := i % sc.connCount + p := patterns[connIdx] + + // Connection establishment + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + // Data transfer + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + // Connection teardown + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + } +} + +// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel +func BenchmarkParallelLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Pre-generate test packets + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + // Each goroutine gets its own counter to distribute load + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + + // Simulate bidirectional traffic + manager.processOutgoingHooks(outPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + }) + } +} + +// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel +func BenchmarkParallelShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs and pre-generate all packet patterns + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + p := patterns[connIdx] + + // Full connection lifecycle + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + }) + } +} + +// generateTCPPacketWithFlags creates a TCP packet with specific flags +func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolTCP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + + // Set TCP flags + tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 + tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 + tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 + tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 + tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d7c93cb7f..d3563e6f2 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "sync" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -185,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager := &Manager{ - incomingRules: map[string]RuleSet{}, - outgoingRules: map[string]RuleSet{}, - } + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -313,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to set network layer for checksum: %v", err) return } - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -325,7 +327,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.outgoingRules, false) { + if m.dropFilter(buf.Bytes(), m.outgoingRules) { t.Errorf("expected packet to be accepted") return } @@ -348,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) { if err != nil { t.Fatalf("Failed to create Manager: %s", err) } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } @@ -384,6 +389,88 @@ func TestRemovePacketHook(t *testing.T) { } } +func TestProcessOutgoingHooks(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + manager.udpTracker.Close() + manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + + hookCalled := false + hookID := manager.AddUDPPacketHook( + false, + net.ParseIP("100.10.0.100"), + 53, + func([]byte) bool { + hookCalled = true + return true + }, + ) + require.NotEmpty(t, hookID) + + // Create test UDP packet + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: net.ParseIP("100.10.0.1"), + DstIP: net.ParseIP("100.10.0.100"), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: 51334, + DstPort: 53, + } + + err = udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + payload := gopacket.Payload("test") + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload) + require.NoError(t, err) + + // Test hook gets called + result := manager.processOutgoingHooks(buf.Bytes()) + require.True(t, result) + require.True(t, hookCalled) + + // Test non-UDP packet is ignored + ipv4.Protocol = layers.IPProtocolTCP + buf = gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, opts, ipv4) + require.NoError(t, err) + + result = manager.processOutgoingHooks(buf.Bytes()) + require.False(t, result) +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { @@ -418,3 +505,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) { }) } } + +func TestStatefulFirewall_UDPTracking(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + manager.udpTracker.Close() // Close the existing tracker + manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + // Set up packet parameters + srcIP := net.ParseIP("100.10.0.1") + dstIP := net.ParseIP("100.10.0.100") + srcPort := uint16(51334) + dstPort := uint16(53) + + // Create outbound packet + outboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolUDP, + } + outboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + + err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + require.NoError(t, err) + + outboundBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err = gopacket.SerializeLayers(outboundBuf, opts, + outboundIPv4, + outboundUDP, + gopacket.Payload("test"), + ) + require.NoError(t, err) + + // Process outbound packet and verify connection tracking + drop := manager.DropOutgoing(outboundBuf.Bytes()) + require.False(t, drop, "Initial outbound packet should not be dropped") + + // Verify connection was tracked + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + + require.True(t, exists, "Connection should be tracked after outbound packet") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") + require.Equal(t, srcPort, conn.SourcePort, "Source port should match") + require.Equal(t, dstPort, conn.DestPort, "Destination port should match") + + // Create valid inbound response packet + inboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: dstIP, // Original destination is now source + DstIP: srcIP, // Original source is now destination + Protocol: layers.IPProtocolUDP, + } + inboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(dstPort), // Original destination port is now source + DstPort: layers.UDPPort(srcPort), // Original source port is now destination + } + + err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4) + require.NoError(t, err) + + inboundBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(inboundBuf, opts, + inboundIPv4, + inboundUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + // Test roundtrip response handling over time + checkPoints := []struct { + sleep time.Duration + shouldAllow bool + description string + }{ + { + sleep: 0, + shouldAllow: true, + description: "Immediate response should be allowed", + }, + { + sleep: 50 * time.Millisecond, + shouldAllow: true, + description: "Response within timeout should be allowed", + }, + { + sleep: 100 * time.Millisecond, + shouldAllow: true, + description: "Response at half timeout should be allowed", + }, + { + // tracker hasn't updated conn for 250ms -> greater than 200ms timeout + sleep: 250 * time.Millisecond, + shouldAllow: false, + description: "Response after timeout should be dropped", + }, + } + + for _, cp := range checkPoints { + time.Sleep(cp.sleep) + + drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + require.Equal(t, cp.shouldAllow, !drop, cp.description) + + // If the connection should still be valid, verify it exists + if cp.shouldAllow { + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should still exist during valid window") + require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), + "LastSeen should be updated for valid responses") + } + } + + // Test invalid response packets (while connection is expired) + invalidCases := []struct { + name string + modifyFunc func(*layers.IPv4, *layers.UDP) + description string + }{ + { + name: "wrong source IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.SrcIP = net.ParseIP("100.10.0.101") + }, + description: "Response from wrong IP should be dropped", + }, + { + name: "wrong destination IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.DstIP = net.ParseIP("100.10.0.2") + }, + description: "Response to wrong IP should be dropped", + }, + { + name: "wrong source port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.SrcPort = 54 + }, + description: "Response from wrong port should be dropped", + }, + { + name: "wrong destination port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.DstPort = 51335 + }, + description: "Response to wrong port should be dropped", + }, + } + + // Create a new outbound connection for invalid tests + drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Second outbound packet should not be dropped") + + for _, tc := range invalidCases { + t.Run(tc.name, func(t *testing.T) { + testIPv4 := *inboundIPv4 + testUDP := *inboundUDP + + tc.modifyFunc(&testIPv4, &testUDP) + + err = testUDP.SetNetworkLayerForChecksum(&testIPv4) + require.NoError(t, err) + + testBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(testBuf, opts, + &testIPv4, + &testUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + + // Verify the invalid packet is dropped + drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + require.True(t, drop, tc.description) + }) + } +}