From 9180fb3f90ff3b6fb2d8671e970e5217137e559e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 27 Jul 2025 15:10:20 +0200 Subject: [PATCH] Eliminate proto strings in route ACLs --- client/firewall/uspfilter/filter.go | 88 ++++++++++--------- .../firewall/uspfilter/filter_bench_test.go | 2 +- .../firewall/uspfilter/filter_filter_test.go | 8 +- client/firewall/uspfilter/filter_test.go | 16 ++-- client/firewall/uspfilter/rule.go | 2 +- client/firewall/uspfilter/tracer.go | 4 +- 6 files changed, 63 insertions(+), 57 deletions(-) diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 3d0b66565..a51af4396 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -27,7 +27,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -const layerTypeAll = 0 +const layerTypeAll = 255 const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. @@ -225,10 +225,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe } func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { - wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) - if err != nil { - return nil, fmt.Errorf("parse wireguard network: %w", err) - } + wgPrefix := iface.Address().Network log.Debugf("blocking invalid routed traffic for %s", wgPrefix) rule, err := m.addRouteFiltering( @@ -402,19 +399,7 @@ func (m *Manager) AddPeerFiltering( r.sPort = sPort r.dPort = dPort - switch proto { - case firewall.ProtocolTCP: - r.protoLayer = layers.LayerTypeTCP - case firewall.ProtocolUDP: - r.protoLayer = layers.LayerTypeUDP - case firewall.ProtocolICMP: - r.protoLayer = layers.LayerTypeICMPv4 - if r.ipLayer == layers.LayerTypeIPv6 { - r.protoLayer = layers.LayerTypeICMPv6 - } - case firewall.ProtocolALL: - r.protoLayer = layerTypeAll - } + r.protoLayer = protoToLayer(proto, r.ipLayer) m.mutex.Lock() if _, ok := m.incomingRules[r.ip]; !ok { @@ -452,16 +437,17 @@ func (m *Manager) addRouteFiltering( } ruleID := uuid.New().String() + rule := RouteRule{ // TODO: consolidate these IDs - id: ruleID, - mgmtId: id, - sources: sources, - dstSet: destination.Set, - proto: proto, - srcPort: sPort, - dstPort: dPort, - action: action, + id: ruleID, + mgmtId: id, + sources: sources, + dstSet: destination.Set, + protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), + srcPort: sPort, + dstPort: dPort, + action: action, } if destination.IsPrefix() { rule.destinations = []netip.Prefix{destination.Prefix} @@ -763,7 +749,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) if blocked { - _, pnum := getProtocolFromPacket(d) + pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", @@ -830,20 +816,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return false } - proto, pnum := getProtocolFromPacket(d) + protoLayer := d.decoded[1] srcPort, dstPort := getPortsFromPacket(d) - ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort) if !pass { + proto := getProtocolFromPacket(d) + m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + ruleID, proto, srcIP, srcPort, dstIP, dstPort) m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), Type: nftypes.TypeDrop, RuleID: ruleID, Direction: nftypes.Ingress, - Protocol: pnum, + Protocol: proto, SourceIP: srcIP, DestIP: dstIP, SourcePort: srcPort, @@ -872,16 +860,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe return true } -func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) { +func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { + switch proto { + case firewall.ProtocolTCP: + return layers.LayerTypeTCP + case firewall.ProtocolUDP: + return layers.LayerTypeUDP + case firewall.ProtocolICMP: + if ipLayer == layers.LayerTypeIPv6 { + return layers.LayerTypeICMPv6 + } + return layers.LayerTypeICMPv4 + case firewall.ProtocolALL: + return layerTypeAll + } + return 0 +} + +func getProtocolFromPacket(d *decoder) nftypes.Protocol { switch d.decoded[1] { case layers.LayerTypeTCP: - return firewall.ProtocolTCP, nftypes.TCP + return nftypes.TCP case layers.LayerTypeUDP: - return firewall.ProtocolUDP, nftypes.UDP + return nftypes.UDP case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: - return firewall.ProtocolICMP, nftypes.ICMP + return nftypes.ICMP default: - return firewall.ProtocolALL, nftypes.ProtocolUnknown + return nftypes.ProtocolUnknown } } @@ -1049,24 +1054,25 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d } // routeACLsPass returns true if the packet is allowed by the route ACLs -func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { +func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) { m.mutex.RLock() defer m.mutex.RUnlock() for _, rule := range m.routeRules { - if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches { + if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches { return rule.mgmtId, rule.action == firewall.ActionAccept } } return nil, false } -func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { - if rule.proto != firewall.ProtocolALL && rule.proto != proto { +func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { + // TODO: handle ipv6 vs ipv4 icmp rules + if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer { return false } - if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { + if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP { if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) { return false } diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 0cffcc1a7..f35b860b3 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -986,7 +986,7 @@ func BenchmarkRouteACLs(b *testing.B) { for _, tc := range cases { srcIP := netip.MustParseAddr(tc.srcIP) dstIP := netip.MustParseAddr(tc.dstIP) - manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) + manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort) } } } diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index b630c9e66..5fd0082ad 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -1235,7 +1235,7 @@ func TestRouteACLFiltering(t *testing.T) { // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed // to the forwarder - _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort) require.Equal(t, tc.shouldPass, isAllowed) }) } @@ -1421,7 +1421,7 @@ func TestRouteACLOrder(t *testing.T) { srcIP := netip.MustParseAddr(p.srcIP) dstIP := netip.MustParseAddr(p.dstIP) - _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) } }) @@ -1464,13 +1464,13 @@ func TestRouteACLSet(t *testing.T) { dstIP := netip.MustParseAddr("192.168.1.100") // Check that traffic is dropped (empty set shouldn't match anything) - _, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.False(t, isAllowed, "Empty set should not allow any traffic") err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) require.NoError(t, err) // Now the packet should be allowed - _, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) + _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") } diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 5b5cd5a53..defba192a 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -737,9 +737,9 @@ func TestUpdateSetMerge(t *testing.T) { dstIP2 := netip.MustParseAddr("192.168.1.100") dstIP3 := netip.MustParseAddr("172.16.0.100") - _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) - _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) - _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) + _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") @@ -754,8 +754,8 @@ func TestUpdateSetMerge(t *testing.T) { require.NoError(t, err) // Check that all original prefixes are still included - _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) - _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) + _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") @@ -763,8 +763,8 @@ func TestUpdateSetMerge(t *testing.T) { dstIP4 := netip.MustParseAddr("172.16.1.100") dstIP5 := netip.MustParseAddr("10.1.0.50") - _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) - _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) + _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) + _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") @@ -892,7 +892,7 @@ func TestUpdateSetDeduplication(t *testing.T) { srcIP := netip.MustParseAddr("100.10.0.1") for _, tc := range testCases { - _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) + _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) require.Equal(t, tc.expected, isAllowed, tc.desc) } } diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index b765c72e9..dbe3a7858 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -34,7 +34,7 @@ type RouteRule struct { sources []netip.Prefix dstSet firewall.Set destinations []netip.Prefix - proto firewall.Protocol + protoLayer gopacket.LayerType srcPort *firewall.Port dstPort *firewall.Port action firewall.Action diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index ef04f2700..ba1aa9c3c 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -367,9 +367,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace { } func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace { - proto, _ := getProtocolFromPacket(d) + protoLayer := d.decoded[1] srcPort, dstPort := getPortsFromPacket(d) - id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort) strId := string(id) if id == nil {