diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 1982e4e7e..4808c3090 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -4,6 +4,7 @@ package uspfilter import ( "context" + "net/netip" "time" log "github.com/sirupsen/logrus" @@ -17,8 +18,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - m.outgoingRules = make(map[string]RuleSet) - m.incomingRules = make(map[string]RuleSet) + m.outgoingRules = make(map[netip.Addr]RuleSet) + m.incomingRules = make(map[netip.Addr]RuleSet) if m.udpTracker != nil { m.udpTracker.Close() @@ -35,8 +36,8 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } - if m.forwarder != nil { - m.forwarder.Stop() + if fwder := m.forwarder.Load(); fwder != nil { + fwder.Stop() } if m.logger != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index cacabe1b3..ff80fec41 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -3,6 +3,7 @@ package uspfilter import ( "context" "fmt" + "net/netip" "os/exec" "syscall" "time" @@ -26,8 +27,8 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - m.outgoingRules = make(map[string]RuleSet) - m.incomingRules = make(map[string]RuleSet) + m.outgoingRules = make(map[netip.Addr]RuleSet) + m.incomingRules = make(map[netip.Addr]RuleSet) if m.udpTracker != nil { m.udpTracker.Close() @@ -44,8 +45,8 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) } - if m.forwarder != nil { - m.forwarder.Stop() + if fwder := m.forwarder.Load(); fwder != nil { + fwder.Stop() } if m.logger != nil { diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 915b57549..b87ec6fd2 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -2,7 +2,6 @@ package conntrack import ( "fmt" - "net" "net/netip" "sync/atomic" "time" @@ -52,16 +51,3 @@ type ConnKey struct { func (c ConnKey) String() string { return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) } - -// makeConnKey creates a connection key -func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { - srcAddr, _ := netip.AddrFromSlice(srcIP) - dstAddr, _ := netip.AddrFromSlice(dstIP) - - return ConnKey{ - SrcIP: srcAddr, - DstIP: dstAddr, - SrcPort: srcPort, - DstPort: dstPort, - } -} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 6d1ed5890..e83b6aa85 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -2,7 +2,7 @@ package conntrack import ( "context" - "net" + "net/netip" "testing" "github.com/sirupsen/logrus" @@ -21,11 +21,11 @@ func BenchmarkMemoryPressure(b *testing.B) { defer tracker.Close() // Generate different IPs - srcIPs := make([]net.IP, 100) - dstIPs := make([]net.IP, 100) + srcIPs := make([]netip.Addr, 100) + dstIPs := make([]netip.Addr, 100) for i := 0; i < 100; i++ { - srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) - dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)}) + dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) } b.ResetTimer() @@ -46,11 +46,11 @@ func BenchmarkMemoryPressure(b *testing.B) { defer tracker.Close() // Generate different IPs - srcIPs := make([]net.IP, 100) - dstIPs := make([]net.IP, 100) + srcIPs := make([]netip.Addr, 100) + dstIPs := make([]netip.Addr, 100) for i := 0; i < 100; i++ { - srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) - dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)}) + dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) } b.ResetTimer() diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index fda08154c..97617a52d 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -2,7 +2,6 @@ package conntrack import ( "fmt" - "net" "net/netip" "sync" "time" @@ -70,8 +69,13 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty return tracker } -func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) { - key := makeICMPKey(srcIP, dstIP, id, seq) +func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16) (ICMPConnKey, bool) { + key := ICMPConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + ID: id, + Sequence: seq, + } t.mutex.RLock() conn, exists := t.connections[key] @@ -87,7 +91,7 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq } // TrackOutbound records an outbound ICMP connection -func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { +func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { // if (inverted direction) conn is not tracked, track this direction t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress) @@ -95,12 +99,12 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u } // TrackInbound records an inbound ICMP Echo Request -func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { +func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress) } // track is the common implementation for tracking both inbound and outbound ICMP connections -func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { +func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { // TODO: icmp doesn't need to extend the timeout key, exists := t.updateIfExists(srcIP, dstIP, id, seq) if exists { @@ -112,7 +116,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) - t.sendStartEvent(direction, key, typ, code) + t.sendStartEvent(direction, srcIP, dstIP, typ, code) return } @@ -120,8 +124,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t BaseConnTrack: BaseConnTrack{ FlowId: uuid.New(), Direction: direction, - SourceIP: key.SrcIP, - DestIP: key.DstIP, + SourceIP: srcIP, + DestIP: dstIP, }, ICMPType: typ, ICMPCode: code, @@ -133,16 +137,21 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t t.mutex.Unlock() t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) - t.sendEvent(nftypes.TypeStart, key, conn) + t.sendEvent(nftypes.TypeStart, conn) } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request -func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { +func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, seq uint16, icmpType uint8) bool { if icmpType != uint8(layers.ICMPv4TypeEchoReply) { return false } - key := makeICMPKey(dstIP, srcIP, id, seq) + key := ICMPConnKey{ + SrcIP: dstIP, + DstIP: srcIP, + ID: id, + Sequence: seq, + } t.mutex.RLock() conn, exists := t.connections[key] @@ -177,7 +186,7 @@ func (t *ICMPTracker) cleanup() { delete(t.connections, key) t.logger.Debug("Removed ICMP connection %s (timeout)", &key) - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) } } } @@ -192,40 +201,28 @@ func (t *ICMPTracker) Close() { t.mutex.Unlock() } -func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) { +func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: conn.FlowId, Type: typ, Direction: conn.Direction, Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 - SourceIP: key.SrcIP, - DestIP: key.DstIP, + SourceIP: conn.SourceIP, + DestIP: conn.DestIP, ICMPType: conn.ICMPType, ICMPCode: conn.ICMPCode, }) } -func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) { +func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeStart, Direction: direction, Protocol: nftypes.ICMP, - SourceIP: key.SrcIP, - DestIP: key.DstIP, + SourceIP: srcIP, + DestIP: dstIP, ICMPType: typ, ICMPCode: code, }) } - -// makeICMPKey creates an ICMP connection key -func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { - srcAddr, _ := netip.AddrFromSlice(srcIP) - dstAddr, _ := netip.AddrFromSlice(dstIP) - return ICMPConnKey{ - SrcIP: srcAddr, - DstIP: dstAddr, - ID: id, - Sequence: seq, - } -} diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index b8328ae94..259cc21a4 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -1,7 +1,7 @@ package conntrack import ( - "net" + "net/netip" "testing" ) @@ -10,8 +10,8 @@ func BenchmarkICMPTracker(b *testing.B) { tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") b.ResetTimer() for i := 0; i < b.N; i++ { @@ -23,8 +23,8 @@ func BenchmarkICMPTracker(b *testing.B) { tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") // Pre-populate some connections for i := 0; i < 1000; i++ { diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 1e6364f68..0c0e7bd99 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -3,7 +3,7 @@ package conntrack // TODO: Send RST packets for invalid/timed-out connections import ( - "net" + "net/netip" "sync" "sync/atomic" "time" @@ -144,8 +144,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { - key := makeConnKey(srcIP, dstIP, srcPort, dstPort) +func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } t.mutex.RLock() conn, exists := t.connections[key] @@ -154,7 +159,6 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, if exists { conn.Lock() t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) - conn.UpdateLastSeen() conn.Unlock() return key, true @@ -164,7 +168,7 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, } // TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { +func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists { // if (inverted direction) conn is not tracked, track this direction t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress) @@ -172,12 +176,12 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { +func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) { t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { +func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags) if exists { return @@ -187,14 +191,13 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u BaseConnTrack: BaseConnTrack{ FlowId: uuid.New(), Direction: direction, - SourceIP: key.SrcIP, - DestIP: key.DstIP, + SourceIP: srcIP, + DestIP: dstIP, SourcePort: srcPort, DestPort: dstPort, }, } - conn.UpdateLastSeen() conn.established.Store(false) conn.tombstone.Store(false) @@ -205,12 +208,17 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u t.connections[key] = conn t.mutex.Unlock() - t.sendEvent(nftypes.TypeStart, key, conn) + t.sendEvent(nftypes.TypeStart, conn) } // IsValidInbound checks if an inbound TCP packet matches a tracked connection -func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { - key := makeConnKey(dstIP, srcIP, dstPort, srcPort) +func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8) bool { + key := ConnKey{ + SrcIP: dstIP, + DstIP: srcIP, + SrcPort: dstPort, + DstPort: srcPort, + } t.mutex.RLock() conn, exists := t.connections[key] @@ -233,13 +241,12 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, conn.Unlock() t.logger.Trace("TCP connection reset: %s", key) - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) return true } conn.Lock() t.updateState(key, conn, flags, false) - conn.UpdateLastSeen() isEstablished := conn.IsEstablished() isValidState := t.isValidStateForFlags(conn.State, flags) conn.Unlock() @@ -249,6 +256,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, // updateState updates the TCP connection state based on flags func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { + conn.UpdateLastSeen() + state := conn.State defer func() { if state != conn.State { @@ -312,7 +321,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i conn.State = TCPStateTimeWait t.logger.Trace("TCP connection %s completed", key) - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) } case TCPStateClosing: @@ -321,7 +330,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i // Keep established = false from previous state t.logger.Trace("TCP connection %s closed (simultaneous)", key) - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) } case TCPStateCloseWait: @@ -335,7 +344,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i conn.SetTombstone() // Send close event for gracefully closed connections - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) t.logger.Trace("TCP connection %s closed gracefully", key) } } @@ -422,7 +431,7 @@ func (t *TCPTracker) cleanup() { // event already handled by state change if conn.State != TCPStateTimeWait { - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) } } } @@ -453,15 +462,15 @@ func isValidFlagCombination(flags uint8) bool { return true } -func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) { +func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: conn.FlowId, Type: typ, Direction: conn.Direction, Protocol: nftypes.TCP, - SourceIP: key.SrcIP, - DestIP: key.DstIP, - SourcePort: key.SrcPort, - DestPort: key.DstPort, + SourceIP: conn.SourceIP, + DestIP: conn.DestIP, + SourcePort: conn.SourcePort, + DestPort: conn.DestPort, }) } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 200d77501..122deae1e 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -1,7 +1,7 @@ package conntrack import ( - "net" + "net/netip" "testing" "time" @@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) { tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("100.64.0.1") - dstIP := net.ParseIP("100.64.0.2") + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") srcPort := uint16(12345) dstPort := uint16(80) @@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) { tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("100.64.0.1") - dstIP := net.ParseIP("100.64.0.2") + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") srcPort := uint16(12345) dstPort := uint16(80) @@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) { tt.sendRST() // Verify connection state is as expected - key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } conn := tracker.connections[key] if tt.wantValid { require.NotNil(t, conn) @@ -220,7 +225,7 @@ func TestRSTHandling(t *testing.T) { } // Helper to establish a TCP connection -func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { +func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { t.Helper() tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) @@ -236,8 +241,8 @@ func BenchmarkTCPTracker(b *testing.B) { tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") b.ResetTimer() for i := 0; i < b.N; i++ { @@ -249,8 +254,8 @@ func BenchmarkTCPTracker(b *testing.B) { tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") // Pre-populate some connections for i := 0; i < 1000; i++ { @@ -267,8 +272,8 @@ func BenchmarkTCPTracker(b *testing.B) { tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") b.RunParallel(func(pb *testing.PB) { i := 0 @@ -291,8 +296,8 @@ func BenchmarkCleanup(b *testing.B) { defer tracker.Close() // Pre-populate with expired connections - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") for i := 0; i < 10000; i++ { tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) } diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 922db371d..465e400be 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -1,7 +1,7 @@ package conntrack import ( - "net" + "net/netip" "sync" "time" @@ -54,7 +54,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp } // TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists { // if (inverted direction) conn is not tracked, track this direction t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress) @@ -62,12 +62,17 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) { t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress) } -func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) { - key := makeConnKey(srcIP, dstIP, srcPort, dstPort) +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) (ConnKey, bool) { + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } t.mutex.RLock() conn, exists := t.connections[key] @@ -82,7 +87,7 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) { +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction) { key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort) if exists { return @@ -92,8 +97,8 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u BaseConnTrack: BaseConnTrack{ FlowId: uuid.New(), Direction: direction, - SourceIP: key.SrcIP, - DestIP: key.DstIP, + SourceIP: srcIP, + DestIP: dstIP, SourcePort: srcPort, DestPort: dstPort, }, @@ -105,12 +110,17 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u t.mutex.Unlock() t.logger.Trace("New %s UDP connection: %s", direction, key) - t.sendEvent(nftypes.TypeStart, key, conn) + t.sendEvent(nftypes.TypeStart, conn) } // IsValidInbound checks if an inbound packet matches a tracked connection -func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { - key := makeConnKey(dstIP, srcIP, dstPort, srcPort) +func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16) bool { + key := ConnKey{ + SrcIP: dstIP, + DstIP: srcIP, + SrcPort: dstPort, + DstPort: srcPort, + } t.mutex.RLock() conn, exists := t.connections[key] @@ -146,7 +156,7 @@ func (t *UDPTracker) cleanup() { delete(t.connections, key) t.logger.Trace("Removed UDP connection %s (timeout)", key) - t.sendEvent(nftypes.TypeEnd, key, conn) + t.sendEvent(nftypes.TypeEnd, conn) } } } @@ -162,11 +172,16 @@ func (t *UDPTracker) Close() { } // GetConnection safely retrieves a connection state -func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { +func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) { t.mutex.RLock() defer t.mutex.RUnlock() - key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } conn, exists := t.connections[key] return conn, exists } @@ -176,15 +191,15 @@ func (t *UDPTracker) Timeout() time.Duration { return t.timeout } -func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) { +func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack) { t.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: conn.FlowId, Type: typ, Direction: conn.Direction, Protocol: nftypes.UDP, - SourceIP: key.SrcIP, - DestIP: key.DstIP, - SourcePort: key.SrcPort, - DestPort: key.DstPort, + SourceIP: conn.SourceIP, + DestIP: conn.DestIP, + SourcePort: conn.SourcePort, + DestPort: conn.DestPort, }) } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 29c7111fd..db7fa0f51 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -1,7 +1,6 @@ package conntrack import ( - "net" "net/netip" "testing" "time" @@ -49,10 +48,15 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { srcPort := uint16(12345) dstPort := uint16(53) - tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) // Verify connection was tracked - key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } conn, exists := tracker.connections[key] require.True(t, exists) assert.True(t, conn.SourceIP.Compare(srcIP) == 0) @@ -66,8 +70,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { tracker := NewUDPTracker(1*time.Second, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.2") - dstIP := net.ParseIP("192.168.1.3") + srcIP := netip.MustParseAddr("192.168.1.2") + dstIP := netip.MustParseAddr("192.168.1.3") srcPort := uint16(12345) dstPort := uint16(53) @@ -76,8 +80,8 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { tests := []struct { name string - srcIP net.IP - dstIP net.IP + srcIP netip.Addr + dstIP netip.Addr srcPort uint16 dstPort uint16 sleep time.Duration @@ -94,7 +98,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { }, { name: "invalid source IP", - srcIP: net.ParseIP("192.168.1.4"), + srcIP: netip.MustParseAddr("192.168.1.4"), dstIP: srcIP, srcPort: dstPort, dstPort: srcPort, @@ -104,7 +108,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) { { name: "invalid destination IP", srcIP: dstIP, - dstIP: net.ParseIP("192.168.1.4"), + dstIP: netip.MustParseAddr("192.168.1.4"), srcPort: dstPort, dstPort: srcPort, sleep: 0, @@ -170,20 +174,20 @@ func TestUDPTracker_Cleanup(t *testing.T) { // Add some connections connections := []struct { - srcIP net.IP - dstIP net.IP + srcIP netip.Addr + dstIP netip.Addr srcPort uint16 dstPort uint16 }{ { - srcIP: net.ParseIP("192.168.1.2"), - dstIP: net.ParseIP("192.168.1.3"), + srcIP: netip.MustParseAddr("192.168.1.2"), + dstIP: netip.MustParseAddr("192.168.1.3"), srcPort: 12345, dstPort: 53, }, { - srcIP: net.ParseIP("192.168.1.4"), - dstIP: net.ParseIP("192.168.1.5"), + srcIP: netip.MustParseAddr("192.168.1.4"), + dstIP: netip.MustParseAddr("192.168.1.5"), srcPort: 12346, dstPort: 53, }, @@ -215,8 +219,8 @@ func BenchmarkUDPTracker(b *testing.B) { tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") b.ResetTimer() for i := 0; i < b.N; i++ { @@ -228,8 +232,8 @@ func BenchmarkUDPTracker(b *testing.B) { tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) defer tracker.Close() - srcIP := net.ParseIP("192.168.1.1") - dstIP := net.ParseIP("192.168.1.2") + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") // Pre-populate some connections for i := 0; i < 1000; i++ { diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index 7664b65d5..b86d16043 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "net/netip" "sync" log "github.com/sirupsen/logrus" @@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) { m.ipv4Bitmap[high] |= 1 << (low % 32) } -func (m *localIPManager) checkBitmapBit(ip net.IP) bool { - ipv4 := ip.To4() - if ipv4 == nil { - return false - } - high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) - low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) +func (m *localIPManager) checkBitmapBit(ip []byte) bool { + high := (uint16(ip[0]) << 8) | uint16(ip[1]) + low := (uint16(ip[2]) << 8) | uint16(ip[3]) return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 } @@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { return nil } -func (m *localIPManager) IsLocalIP(ip net.IP) bool { +func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { m.mu.RLock() defer m.mu.RUnlock() - if ipv4 := ip.To4(); ipv4 != nil { - return m.checkBitmapBit(ipv4) + if ip.Is4() { + return m.checkBitmapBit(ip.AsSlice()) } return false diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 02f41bf4f..890b7a30d 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -2,6 +2,7 @@ package uspfilter import ( "net" + "net/netip" "testing" "github.com/stretchr/testify/require" @@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) { tests := []struct { name string setupAddr iface.WGAddress - testIP net.IP + testIP netip.Addr expected bool }{ { @@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) { Mask: net.CIDRMask(24, 32), }, }, - testIP: net.ParseIP("127.0.0.2"), + testIP: netip.MustParseAddr("127.0.0.2"), expected: true, }, { @@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) { Mask: net.CIDRMask(24, 32), }, }, - testIP: net.ParseIP("127.0.0.1"), + testIP: netip.MustParseAddr("127.0.0.1"), expected: true, }, { @@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) { Mask: net.CIDRMask(24, 32), }, }, - testIP: net.ParseIP("127.255.255.255"), + testIP: netip.MustParseAddr("127.255.255.255"), expected: true, }, { @@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) { Mask: net.CIDRMask(24, 32), }, }, - testIP: net.ParseIP("192.168.1.1"), + testIP: netip.MustParseAddr("192.168.1.1"), expected: true, }, { @@ -73,7 +74,7 @@ func TestLocalIPManager(t *testing.T) { Mask: net.CIDRMask(24, 32), }, }, - testIP: net.ParseIP("192.168.1.2"), + testIP: netip.MustParseAddr("192.168.1.2"), expected: false, }, { @@ -85,7 +86,7 @@ func TestLocalIPManager(t *testing.T) { Mask: net.CIDRMask(64, 128), }, }, - testIP: net.ParseIP("fe80::1"), + testIP: netip.MustParseAddr("fe80::1"), expected: false, }, } @@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) { t.Logf("Testing %d IPs", len(tests)) for _, tt := range tests { t.Run(tt.ip, func(t *testing.T) { - result := manager.IsLocalIP(net.ParseIP(tt.ip)) + result := manager.IsLocalIP(netip.MustParseAddr(tt.ip)) require.Equal(t, tt.expected, result, "IP: %s", tt.ip) }) } diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index 7a587c832..a23d2011b 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -1,7 +1,6 @@ package uspfilter import ( - "net" "net/netip" "github.com/google/gopacket" @@ -13,7 +12,7 @@ import ( type PeerRule struct { id string mgmtId []byte - ip net.IP + ip netip.Addr ipLayer gopacket.LayerType matchByIP bool protoLayer gopacket.LayerType diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index aff886b58..a980cda29 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -2,7 +2,7 @@ package uspfilter import ( "fmt" - "net" + "net/netip" "time" "github.com/google/gopacket" @@ -53,8 +53,8 @@ type TraceResult struct { } type PacketTrace struct { - SourceIP net.IP - DestinationIP net.IP + SourceIP netip.Addr + DestinationIP netip.Addr Protocol string SourcePort uint16 DestinationPort uint16 @@ -72,8 +72,8 @@ type TCPState struct { } type PacketBuilder struct { - SrcIP net.IP - DstIP net.IP + SrcIP netip.Addr + DstIP netip.Addr Protocol fw.Protocol SrcPort uint16 DstPort uint16 @@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { Version: 4, TTL: 64, Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), - SrcIP: p.SrcIP, - DstIP: p.DstIP, + SrcIP: p.SrcIP.AsSlice(), + DstIP: p.DstIP.AsSlice(), } } @@ -260,7 +260,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa return m.traceInbound(packetData, trace, d, srcIP, dstIP) } -func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { +func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -273,14 +273,14 @@ func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder return trace } - if m.nativeRouter { + if m.nativeRouter.Load() { return m.handleNativeRouter(trace) } return m.handleRouteACLs(trace, d, srcIP, dstIP) } -func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { +func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool { allowed := m.isValidTrackedConnection(d, srcIP, dstIP) msg := "No existing connection found" if allowed { @@ -309,13 +309,12 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string { return msg } -func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { +func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { if !m.localForwarding { trace.AddResult(StageRouting, "Local forwarding disabled", false) trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) return true } - trace.AddResult(StageRouting, "Packet destined for local delivery", true) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) @@ -341,7 +340,7 @@ func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d * } func (m *Manager) handleRouting(trace *PacketTrace) bool { - if !m.routingEnabled { + if !m.routingEnabled.Load() { trace.AddResult(StageRouting, "Routing disabled", false) trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) return false @@ -357,7 +356,7 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace { return trace } -func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { +func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace { proto, _ := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) @@ -373,7 +372,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n } trace.AddResult(StageRouteACL, msg, allowed) - if allowed && m.forwarder != nil { + if allowed && m.forwarder.Load() != nil { m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) } diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go new file mode 100644 index 000000000..e69de29bb diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 01a3976b4..05dd9ff06 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -66,9 +67,9 @@ func (r RouteRules) Sort() { // Manager userspace firewall manager type Manager struct { // outgoingRules is used for hooks only - outgoingRules map[string]RuleSet + outgoingRules map[netip.Addr]RuleSet // incomingRules is used for filtering and hooks - incomingRules map[string]RuleSet + incomingRules map[netip.Addr]RuleSet routeRules RouteRules wgNetwork *net.IPNet decoders sync.Pool @@ -80,9 +81,9 @@ type Manager struct { // indicates whether server routes are disabled disableServerRoutes bool // indicates whether we forward packets not destined for ourselves - routingEnabled bool + routingEnabled atomic.Bool // indicates whether we leave forwarding and filtering to the native firewall - nativeRouter bool + nativeRouter atomic.Bool // indicates whether we track outbound connections stateful bool // indicates whether wireguards runs in netstack mode @@ -95,7 +96,7 @@ type Manager struct { udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker - forwarder *forwarder.Forwarder + forwarder atomic.Pointer[forwarder.Forwarder] logger *nblog.Logger flowLogger nftypes.FlowLogger } @@ -168,18 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe }, }, nativeFirewall: nativeFirewall, - outgoingRules: make(map[string]RuleSet), - incomingRules: make(map[string]RuleSet), + outgoingRules: make(map[netip.Addr]RuleSet), + incomingRules: make(map[netip.Addr]RuleSet), wgIface: iface, localipmanager: newLocalIPManager(), disableServerRoutes: disableServerRoutes, - routingEnabled: false, stateful: !disableConntrack, logger: nblog.NewFromLogrus(log.StandardLogger()), flowLogger: flowLogger, netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, } + m.routingEnabled.Store(false) if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) @@ -211,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe } func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { - if m.forwarder == nil { + if m.forwarder.Load() == nil { return nil } wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) @@ -255,20 +256,20 @@ func (m *Manager) determineRouting() error { switch { case disableUspRouting: - m.routingEnabled = false - m.nativeRouter = false + m.routingEnabled.Store(false) + m.nativeRouter.Store(false) log.Info("userspace routing is disabled") case m.disableServerRoutes: // if server routes are disabled we will let packets pass to the native stack - m.routingEnabled = true - m.nativeRouter = true + m.routingEnabled.Store(true) + m.nativeRouter.Store(true) log.Info("server routes are disabled") case forceUserspaceRouter: - m.routingEnabled = true - m.nativeRouter = false + m.routingEnabled.Store(true) + m.nativeRouter.Store(false) log.Info("userspace routing is forced") @@ -276,19 +277,19 @@ func (m *Manager) determineRouting() error { // if the OS supports routing natively, then we don't need to filter/route ourselves // netstack mode won't support native routing as there is no interface - m.routingEnabled = true - m.nativeRouter = true + m.routingEnabled.Store(true) + m.nativeRouter.Store(true) log.Info("native routing is enabled") default: - m.routingEnabled = true - m.nativeRouter = false + m.routingEnabled.Store(true) + m.nativeRouter.Store(false) log.Info("userspace routing enabled by default") } - if m.routingEnabled && !m.nativeRouter { + if m.routingEnabled.Load() && !m.nativeRouter.Load() { return m.initForwarder() } @@ -297,24 +298,24 @@ func (m *Manager) determineRouting() error { // initForwarder initializes the forwarder, it disables routing on errors func (m *Manager) initForwarder() error { - if m.forwarder != nil { + if m.forwarder.Load() != nil { return nil } // Only supported in userspace mode as we need to inject packets back into wireguard directly intf := m.wgIface.GetWGDevice() if intf == nil { - m.routingEnabled = false + m.routingEnabled.Store(false) return errors.New("forwarding not supported") } forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) if err != nil { - m.routingEnabled = false + m.routingEnabled.Store(false) return fmt.Errorf("create forwarder: %w", err) } - m.forwarder = forwarder + m.forwarder.Store(forwarder) log.Debug("forwarder initialized") @@ -330,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool { } func (m *Manager) AddNatRule(pair firewall.RouterPair) error { - if m.nativeRouter && m.nativeFirewall != nil { + if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.AddNatRule(pair) } @@ -341,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { // RemoveNatRule removes a routing firewall rule func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { - if m.nativeRouter && m.nativeFirewall != nil { + if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.RemoveNatRule(pair) } return nil @@ -360,17 +361,23 @@ func (m *Manager) AddPeerFiltering( action firewall.Action, _ string, ) ([]firewall.Rule, error) { + // TODO: fix in upper layers + i, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("invalid IP: %s", ip) + } + + i = i.Unmap() r := PeerRule{ id: uuid.New().String(), mgmtId: id, - ip: ip, + ip: i, ipLayer: layers.LayerTypeIPv6, matchByIP: true, drop: action == firewall.ActionDrop, } - if ipNormalized := ip.To4(); ipNormalized != nil { + if i.Is4() { r.ipLayer = layers.LayerTypeIPv4 - r.ip = ipNormalized } if s := r.ip.String(); s == "0.0.0.0" || s == "::" { @@ -395,10 +402,10 @@ func (m *Manager) AddPeerFiltering( } m.mutex.Lock() - if _, ok := m.incomingRules[r.ip.String()]; !ok { - m.incomingRules[r.ip.String()] = make(RuleSet) + if _, ok := m.incomingRules[r.ip]; !ok { + m.incomingRules[r.ip] = make(RuleSet) } - m.incomingRules[r.ip.String()][r.id] = r + m.incomingRules[r.ip][r.id] = r m.mutex.Unlock() return []firewall.Rule{&r}, nil } @@ -412,13 +419,10 @@ func (m *Manager) AddRouteFiltering( dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { - if m.nativeRouter && m.nativeFirewall != nil { + if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } - m.mutex.Lock() - defer m.mutex.Unlock() - ruleID := uuid.New().String() rule := RouteRule{ // TODO: consolidate these IDs @@ -432,14 +436,16 @@ func (m *Manager) AddRouteFiltering( action: action, } + m.mutex.Lock() m.routeRules = append(m.routeRules, rule) m.routeRules.Sort() + m.mutex.Unlock() return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { - if m.nativeRouter && m.nativeFirewall != nil { + if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.DeleteRouteRule(rule) } @@ -468,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } - if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok { + if _, ok := m.incomingRules[r.ip][r.id]; !ok { return fmt.Errorf("delete rule: no rule with such id: %v", r.id) } - delete(m.incomingRules[r.ip.String()], r.id) + delete(m.incomingRules[r.ip], r.id) return nil } @@ -519,9 +525,6 @@ func (m *Manager) UpdateLocalIPs() error { } func (m *Manager) processOutgoingHooks(packetData []byte) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() - d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) @@ -534,7 +537,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { } srcIP, dstIP := m.extractIPs(d) - if srcIP == nil { + if !srcIP.IsValid() { + m.logger.Error("Unknown network layer: %v", d.decoded[0]) return false } @@ -551,14 +555,18 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } -func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { +func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) { switch d.decoded[0] { case layers.LayerTypeIPv4: - return d.ip4.SrcIP, d.ip4.DstIP + src, _ := netip.AddrFromSlice(d.ip4.SrcIP) + dst, _ := netip.AddrFromSlice(d.ip4.DstIP) + return src, dst case layers.LayerTypeIPv6: - return d.ip6.SrcIP, d.ip6.DstIP + src, _ := netip.AddrFromSlice(d.ip6.SrcIP) + dst, _ := netip.AddrFromSlice(d.ip6.DstIP) + return src, dst default: - return nil, nil + return netip.Addr{}, netip.Addr{} } } @@ -585,7 +593,7 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: @@ -598,7 +606,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) { } } -func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) { +func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: @@ -611,8 +619,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) { } } -func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { - for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { +func (m *Manager) checkUDPHooks(d *decoder, dstIP netip.Addr, packetData []byte) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + for _, ipKey := range []netip.Addr{dstIP, netip.IPv4Unspecified(), netip.IPv6Unspecified()} { if rules, exists := m.outgoingRules[ipKey]; exists { for _, rule := range rules { if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { @@ -627,9 +638,6 @@ func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) boo // dropFilter implements filtering logic for incoming packets. // If it returns true, the packet should be dropped. func (m *Manager) dropFilter(packetData []byte) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() - d := m.decoders.Get().(*decoder) defer m.decoders.Put(d) @@ -638,7 +646,7 @@ func (m *Manager) dropFilter(packetData []byte) bool { } srcIP, dstIP := m.extractIPs(d) - if srcIP == nil { + if !srcIP.IsValid() { m.logger.Error("Unknown network layer: %v", d.decoded[0]) return true } @@ -658,15 +666,13 @@ func (m *Manager) dropFilter(packetData []byte) bool { // handleLocalTraffic handles local traffic. // If it returns true, the packet should be dropped. -func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { +func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool { if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { - srcAddr, _ := netip.AddrFromSlice(srcIP) - dstAddr, _ := netip.AddrFromSlice(dstIP) _, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort) + ruleId, pnum, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -674,8 +680,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData RuleID: ruleId, Direction: nftypes.Ingress, Protocol: pnum, - SourceIP: srcAddr, - DestIP: dstAddr, + SourceIP: srcIP, + DestIP: dstIP, SourcePort: srcPort, DestPort: dstPort, // TODO: icmp type/code @@ -700,12 +706,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { return false } - if m.forwarder == nil { + if m.forwarder.Load() == nil { m.logger.Trace("Dropping local packet (forwarder not initialized)") return true } - if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject local packet: %v", err) } @@ -715,16 +721,16 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { // handleRoutedTraffic handles routed traffic. // If it returns true, the packet should be dropped. -func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { +func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool { // Drop if routing is disabled - if !m.routingEnabled { + if !m.routingEnabled.Load() { m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", srcIP, dstIP) return true } // Pass to native stack if native router is enabled or forced - if m.nativeRouter { + if m.nativeRouter.Load() { return false } @@ -732,9 +738,6 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat srcPort, dstPort := getPortsFromPacket(d) if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { - srcAddr, _ := netip.AddrFromSlice(srcIP) - dstAddr, _ := netip.AddrFromSlice(dstIP) - m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", id, pnum, srcIP, srcPort, dstIP, dstPort) @@ -744,8 +747,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat RuleID: id, Direction: nftypes.Ingress, Protocol: pnum, - SourceIP: srcAddr, - DestIP: dstAddr, + SourceIP: srcIP, + DestIP: dstIP, SourcePort: srcPort, DestPort: dstPort, // TODO: icmp type/code @@ -754,7 +757,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat } // Let forwarder handle the packet if it passed route ACLs - if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil { m.logger.Error("Failed to inject incoming packet: %v", err) } @@ -799,7 +802,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { return true } -func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr) bool { switch d.decoded[1] { case layers.LayerTypeTCP: return m.tcpTracker.IsValidInbound( @@ -844,20 +847,22 @@ func (m *Manager) isSpecialICMP(d *decoder) bool { icmpType == layers.ICMPv4TypeTimeExceeded } -func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) { +func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() if m.isSpecialICMP(d) { return nil, false } - if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { + if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok { return mgmtId, filter } - if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { + if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok { return mgmtId, filter } - if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { + if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok { return mgmtId, filter } @@ -882,10 +887,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool { return false } -func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { +func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { payloadLayer := d.decoded[1] for _, rule := range rules { - if rule.matchByIP && !ip.Equal(rule.ip) { + if rule.matchByIP && ip.Compare(rule.ip) != 0 { continue } @@ -919,16 +924,13 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de return nil, false, false } -// routeACLsPass returns treu if the packet is allowed by the route ACLs -func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { +// routeACLsPass returns true if the packet is allowed by the route ACLs +func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { m.mutex.RLock() defer m.mutex.RUnlock() - srcAddr := netip.AddrFrom4([4]byte(srcIP.To4())) - dstAddr := netip.AddrFrom4([4]byte(dstIP.To4())) - for _, rule := range m.routeRules { - if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { + if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches { return rule.mgmtId, rule.action == firewall.ActionAccept } } @@ -972,9 +974,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) { // AddUDPPacketHook calls hook when UDP packet from given direction matched // // Hook function returns flag which indicates should be the matched package dropped or not -func (m *Manager) AddUDPPacketHook( - in bool, ip net.IP, dPort uint16, hook func([]byte) bool, -) string { +func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string { r := PeerRule{ id: uuid.New().String(), ip: ip, @@ -984,23 +984,22 @@ func (m *Manager) AddUDPPacketHook( udpHook: hook, } - if ip.To4() != nil { + if ip.Is4() { r.ipLayer = layers.LayerTypeIPv4 } m.mutex.Lock() if in { - if _, ok := m.incomingRules[r.ip.String()]; !ok { - m.incomingRules[r.ip.String()] = make(map[string]PeerRule) + if _, ok := m.incomingRules[r.ip]; !ok { + m.incomingRules[r.ip] = make(map[string]PeerRule) } - m.incomingRules[r.ip.String()][r.id] = r + m.incomingRules[r.ip][r.id] = r } else { - if _, ok := m.outgoingRules[r.ip.String()]; !ok { - m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) + if _, ok := m.outgoingRules[r.ip]; !ok { + m.outgoingRules[r.ip] = make(map[string]PeerRule) } - m.outgoingRules[r.ip.String()][r.id] = r + m.outgoingRules[r.ip][r.id] = r } - m.mutex.Unlock() return r.id @@ -1048,20 +1047,21 @@ func (m *Manager) DisableRouting() error { m.mutex.Lock() defer m.mutex.Unlock() - if m.forwarder == nil { + fwder := m.forwarder.Load() + if fwder == nil { return nil } - m.routingEnabled = false - m.nativeRouter = false + m.routingEnabled.Store(false) + m.nativeRouter.Store(false) // don't stop forwarder if in use by netstack if m.netstack && m.localForwarding { return nil } - m.forwarder.Stop() - m.forwarder = nil + fwder.Stop() + m.forwarder.Store(nil) log.Debug("forwarder stopped") diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index ea9a4285a..9b06b2803 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { for _, tc := range cases { - srcIP := net.ParseIP(tc.srcIP) - dstIP := net.ParseIP(tc.dstIP) + srcIP := netip.MustParseAddr(tc.srcIP) + dstIP := netip.MustParseAddr(tc.dstIP) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) } } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index d3dbef126..7005d501c 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -306,8 +306,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { require.NoError(tb, manager.EnableRouting()) require.NoError(tb, err) require.NotNil(tb, manager) - require.True(tb, manager.routingEnabled) - require.False(tb, manager.nativeRouter) + require.True(tb, manager.routingEnabled.Load()) + require.False(tb, manager.nativeRouter.Load()) tb.Cleanup(func() { require.NoError(tb, manager.Reset(nil)) @@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) { require.NoError(t, manager.DeleteRouteRule(rule)) }) - srcIP := net.ParseIP(tc.srcIP) - dstIP := net.ParseIP(tc.dstIP) + srcIP := netip.MustParseAddr(tc.srcIP) + dstIP := netip.MustParseAddr(tc.dstIP) // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // to the forwarder @@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) { }) for i, p := range tc.packets { - srcIP := net.ParseIP(p.srcIP) - dstIP := net.ParseIP(p.dstIP) + srcIP := netip.MustParseAddr(p.srcIP) + dstIP := netip.MustParseAddr(p.dstIP) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index e0e2b86c7..429265794 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) { return } - ip := net.ParseIP("192.168.1.1") + ip := netip.MustParseAddr("192.168.1.1") proto := fw.ProtocolTCP port := &fw.Port{Values: []uint16{80}} action := fw.ActionDrop - rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") + rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "") if err != nil { t.Errorf("failed to add filtering: %v", err) return } for _, r := range rule2 { - if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok { + if _, ok := m.incomingRules[ip][r.ID()]; !ok { t.Errorf("rule2 is not in the incomingRules") } } @@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - if _, ok := m.incomingRules[ip.String()][r.ID()]; ok { + if _, ok := m.incomingRules[ip][r.ID()]; ok { t.Errorf("rule2 is not in the incomingRules") } } @@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) { name string in bool expDir fw.RuleDirection - ip net.IP + ip netip.Addr dPort uint16 hook func([]byte) bool expectedID string @@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) { name: "Test Outgoing UDP Packet Hook", in: false, expDir: fw.RuleDirectionOUT, - ip: net.IPv4(10, 168, 0, 1), + ip: netip.MustParseAddr("10.168.0.1"), dPort: 8000, hook: func([]byte) bool { return true }, }, @@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) { name: "Test Incoming UDP Packet Hook", in: true, expDir: fw.RuleDirectionIN, - ip: net.IPv6loopback, + ip: netip.MustParseAddr("::1"), dPort: 9000, hook: func([]byte) bool { return false }, }, @@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) { var addedRule PeerRule if tt.in { - if len(manager.incomingRules[tt.ip.String()]) != 1 { + if len(manager.incomingRules[tt.ip]) != 1 { t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) return } - for _, rule := range manager.incomingRules[tt.ip.String()] { + for _, rule := range manager.incomingRules[tt.ip] { addedRule = rule } } else { @@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) { t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) return } - for _, rule := range manager.outgoingRules[tt.ip.String()] { + for _, rule := range manager.outgoingRules[tt.ip] { addedRule = rule } } - if !tt.ip.Equal(addedRule.ip) { + if tt.ip.Compare(addedRule.ip) != 0 { t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) return } @@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) { // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } - hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) + hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc) // Assert the hook is added by finding it in the manager's outgoing rules found := false @@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) { hookCalled := false hookID := manager.AddUDPPacketHook( false, - net.ParseIP("100.10.0.100"), + netip.MustParseAddr("100.10.0.100"), 53, func([]byte) bool { hookCalled = true @@ -573,7 +573,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.False(t, drop, "Initial outbound packet should not be dropped") // Verify connection was tracked - conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match") @@ -641,7 +641,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { // If the connection should still be valid, verify it exists if cp.shouldAllow { - conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) require.True(t, exists, "Connection should still exist during valid window") require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), "LastSeen should be updated for valid responses") diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index f87f10429..f21804683 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -2,6 +2,7 @@ package device import ( "net" + "net/netip" "sync" "golang.zx2c4.com/wireguard/tun" @@ -19,7 +20,7 @@ type PacketFilter interface { // // Hook function returns flag which indicates should be the matched package dropped or not. // Hook function receives raw network packet data as argument. - AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string + AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string // RemovePacketHook removes hook by ID RemovePacketHook(hookID string) error diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index 6348e0e77..f00024e38 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -6,6 +6,7 @@ package mocks import ( net "net" + "net/netip" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder { } // AddUDPPacketHook mocks base method. -func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { +func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(string) diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 250f3ab2e..34c563757 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -2,7 +2,7 @@ package dns import ( "fmt" - "net" + "net/netip" "sync" "github.com/google/gopacket" @@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return true } - return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil + ip, err := netip.ParseAddr(s.runtimeIP) + if err != nil { + return "", fmt.Errorf("parse runtime ip: %w", err) + } + + return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil } diff --git a/client/server/trace.go b/client/server/trace.go index 66b83d8cf..8b9d375f3 100644 --- a/client/server/trace.go +++ b/client/server/trace.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter" @@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( srcIP = engine.GetWgAddr() } + srcAddr, ok := netip.AddrFromSlice(srcIP) + if !ok { + return nil, fmt.Errorf("invalid source IP address") + } + dstIP := net.ParseIP(req.GetDestinationIp()) if req.GetDestinationIp() == "self" { dstIP = engine.GetWgAddr() } + dstAddr, ok := netip.AddrFromSlice(dstIP) + if !ok { + return nil, fmt.Errorf("invalid source IP address") + } + if srcIP == nil || dstIP == nil { return nil, fmt.Errorf("invalid IP address") } @@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( } builder := &uspfilter.PacketBuilder{ - SrcIP: srcIP, - DstIP: dstIP, + SrcIP: srcAddr, + DstIP: dstAddr, Protocol: protocol, SrcPort: uint16(req.GetSourcePort()), DstPort: uint16(req.GetDestinationPort()),