Eliminate proto strings in route ACLs

This commit is contained in:
Viktor Liu
2025-07-27 15:10:20 +02:00
parent 04caf6c026
commit 9180fb3f90
6 changed files with 63 additions and 57 deletions

View File

@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 255
const ( const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
@ -225,10 +225,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix := iface.Address().Network
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering( rule, err := m.addRouteFiltering(
@ -402,19 +399,7 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort r.sPort = sPort
r.dPort = dPort r.dPort = dPort
switch proto { r.protoLayer = protoToLayer(proto, r.ipLayer)
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
@ -452,13 +437,14 @@ func (m *Manager) addRouteFiltering(
} }
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
proto: proto, protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
@ -763,7 +749,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
if blocked { if blocked {
_, pnum := getProtocolFromPacket(d) pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@ -830,20 +816,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false return false
} }
proto, pnum := getProtocolFromPacket(d) protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
if !pass { if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, proto, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: ruleID, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: proto,
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
@ -872,16 +860,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true return true
} }
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) { func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return firewall.ProtocolTCP, nftypes.TCP return nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return firewall.ProtocolUDP, nftypes.UDP return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP, nftypes.ICMP return nftypes.ICMP
default: default:
return firewall.ProtocolALL, nftypes.ProtocolUnknown return nftypes.ProtocolUnknown
} }
} }
@ -1049,24 +1054,25 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
} }
// routeACLsPass returns true if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches { if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
return nil, false return nil, false
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
if rule.proto != firewall.ProtocolALL && rule.proto != proto { // TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false return false
} }
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) { if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false return false
} }

View File

@ -986,7 +986,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
} }
} }
} }

View File

@ -1235,7 +1235,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)
}) })
} }
@ -1421,7 +1421,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
} }
}) })
@ -1464,13 +1464,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100") dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything) // Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic") require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err) require.NoError(t, err)
// Now the packet should be allowed // Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
} }

View File

@ -737,9 +737,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100") dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100") dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@ -754,8 +754,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Check that all original prefixes are still included // Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@ -763,8 +763,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100") dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50") dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@ -892,7 +892,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1") srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases { for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc) require.Equal(t, tc.expected, isAllowed, tc.desc)
} }
} }

View File

@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix sources []netip.Prefix
dstSet firewall.Set dstSet firewall.Set
destinations []netip.Prefix destinations []netip.Prefix
proto firewall.Protocol protoLayer gopacket.LayerType
srcPort *firewall.Port srcPort *firewall.Port
dstPort *firewall.Port dstPort *firewall.Port
action firewall.Action action firewall.Action

View File

@ -367,9 +367,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d) protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
strId := string(id) strId := string(id)
if id == nil { if id == nil {