From 7cd5dcae59bbbbd1074552ea508de38a47e90141 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 18 Aug 2025 11:17:00 +0200 Subject: [PATCH] [client] Fix rule order for deny rules in peer ACLs (#4147) --- client/firewall/iptables/acl_linux.go | 33 +++++-- .../firewall/iptables/manager_linux_test.go | 81 +++++++++++++++- client/firewall/nftables/acl_linux.go | 18 +++- .../firewall/nftables/manager_linux_test.go | 97 +++++++++++++++++-- client/firewall/uspfilter/allow_netbird.go | 1 + .../uspfilter/allow_netbird_windows.go | 1 + client/firewall/uspfilter/filter.go | 74 +++++++++----- .../firewall/uspfilter/filter_filter_test.go | 31 +++++- client/firewall/uspfilter/filter_test.go | 45 +++++++-- client/firewall/uspfilter/tracer.go | 2 +- 10 files changed, 323 insertions(+), 60 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 183417327..7b90000a8 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -85,7 +85,7 @@ func (m *aclManager) AddPeerFiltering( ) ([]firewall.Rule, error) { chain := chainNameInputRules - ipsetName = transformIPsetName(ipsetName, sPort, dPort) + ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) mangleSpecs := slices.Clone(specs) @@ -135,7 +135,14 @@ func (m *aclManager) AddPeerFiltering( return nil, fmt.Errorf("rule already exists") } - if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil { + // Insert DROP rules at the beginning, append ACCEPT rules at the end + if action == firewall.ActionDrop { + // Insert at the beginning of the chain (position 1) + err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...) + } else { + err = m.iptablesClient.Append(tableFilter, chain, specs...) + } + if err != nil { return nil, err } @@ -388,17 +395,25 @@ func actionToStr(action firewall.Action) string { return "DROP" } -func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string { - switch { - case ipsetName == "": +func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string { + if ipsetName == "" { return "" + } + + // Include action in the ipset name to prevent squashing rules with different actions + actionSuffix := "" + if action == firewall.ActionDrop { + actionSuffix = "-drop" + } + + switch { case sPort != nil && dPort != nil: - return ipsetName + "-sport-dport" + return ipsetName + "-sport-dport" + actionSuffix case sPort != nil: - return ipsetName + "-sport" + return ipsetName + "-sport" + actionSuffix case dPort != nil: - return ipsetName + "-dport" + return ipsetName + "-dport" + actionSuffix default: - return ipsetName + return ipsetName + actionSuffix } } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 30f391a6d..a5cc62feb 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -3,6 +3,7 @@ package iptables import ( "fmt" "net/netip" + "strings" "testing" "time" @@ -15,7 +16,7 @@ import ( var ifaceMock = &iFaceMock{ NameFunc: func() string { - return "lo" + return "wg-test" }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ @@ -109,10 +110,84 @@ func TestIptablesManager(t *testing.T) { }) } +func TestIptablesManagerDenyRules(t *testing.T) { + ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err) + + manager, err := Create(ifaceMock) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) + + defer func() { + err := manager.Close(nil) + require.NoError(t, err) + }() + + t.Run("add deny rule", func(t *testing.T) { + ip := netip.MustParseAddr("10.20.0.3") + port := &fw.Port{Values: []uint16{22}} + + rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh") + require.NoError(t, err, "failed to add deny rule") + require.NotEmpty(t, rule, "deny rule should not be empty") + + // Verify the rule was added by checking iptables + for _, r := range rule { + rr := r.(*Rule) + checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...) + } + }) + + t.Run("deny rule precedence test", func(t *testing.T) { + ip := netip.MustParseAddr("10.20.0.4") + port := &fw.Port{Values: []uint16{80}} + + // Add accept rule first + _, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http") + require.NoError(t, err, "failed to add accept rule") + + // Add deny rule second for same IP/port - this should take precedence + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http") + require.NoError(t, err, "failed to add deny rule") + + // Inspect the actual iptables rules to verify deny rule comes before accept rule + rules, err := ipv4Client.List("filter", chainNameInputRules) + require.NoError(t, err, "failed to list iptables rules") + + // Debug: print all rules + t.Logf("All iptables rules in chain %s:", chainNameInputRules) + for i, rule := range rules { + t.Logf(" [%d] %s", i, rule) + } + + var denyRuleIndex, acceptRuleIndex int = -1, -1 + for i, rule := range rules { + if strings.Contains(rule, "DROP") { + t.Logf("Found DROP rule at index %d: %s", i, rule) + if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") { + denyRuleIndex = i + } + } + if strings.Contains(rule, "ACCEPT") { + t.Logf("Found ACCEPT rule at index %d: %s", i, rule) + if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") { + acceptRuleIndex = i + } + } + } + + require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in iptables") + require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in iptables") + require.Less(t, denyRuleIndex, acceptRuleIndex, + "deny rule should come before accept rule in iptables chain (deny at index %d, accept at index %d)", + denyRuleIndex, acceptRuleIndex) + }) +} + func TestIptablesManagerIPSet(t *testing.T) { mock := &iFaceMock{ NameFunc: func() string { - return "lo" + return "wg-test" }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ @@ -176,7 +251,7 @@ func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName strin func TestIptablesCreatePerformance(t *testing.T) { mock := &iFaceMock{ NameFunc: func() string { - return "lo" + return "wg-test" }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index b6e9a930b..52979d257 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -341,30 +341,38 @@ func (m *AclManager) addIOFiltering( userData := []byte(ruleId) chain := m.chainInputRules - nftRule := m.rConn.AddRule(&nftables.Rule{ + rule := &nftables.Rule{ Table: m.workTable, Chain: chain, Exprs: mainExpressions, UserData: userData, - }) + } + + // Insert DROP rules at the beginning, append ACCEPT rules at the end + var nftRule *nftables.Rule + if action == firewall.ActionDrop { + nftRule = m.rConn.InsertRule(rule) + } else { + nftRule = m.rConn.AddRule(rule) + } if err := m.rConn.Flush(); err != nil { return nil, fmt.Errorf(flushError, err) } - rule := &Rule{ + ruleStruct := &Rule{ nftRule: nftRule, mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, ip: ip, } - m.rules[ruleId] = rule + m.rules[ruleId] = ruleStruct if ipset != nil { m.ipsetStore.AddReferenceToIpset(ipset.Name) } - return rule, nil + return ruleStruct, nil } func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule { diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 1dd3e9183..c7f05dcb7 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -2,6 +2,7 @@ package nftables import ( "bytes" + "encoding/binary" "fmt" "net/netip" "os/exec" @@ -20,7 +21,7 @@ import ( var ifaceMock = &iFaceMock{ NameFunc: func() string { - return "lo" + return "wg-test" }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ @@ -103,9 +104,8 @@ func TestNftablesManager(t *testing.T) { Kind: expr.VerdictAccept, }, } - compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) - - expectedExprs2 := []expr.Any{ + // Since DROP rules are inserted at position 0, the DROP rule comes first + expectedDropExprs := []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, @@ -141,7 +141,12 @@ func TestNftablesManager(t *testing.T) { }, &expr.Verdict{Kind: expr.VerdictDrop}, } - require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") + + // Compare DROP rule at position 0 (inserted first due to InsertRule) + compareExprsIgnoringCounters(t, rules[0].Exprs, expectedDropExprs) + + // Compare connection tracking rule at position 1 (pushed down by DROP rule insertion) + compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1) for _, r := range rule { err = manager.DeletePeerRule(r) @@ -160,10 +165,90 @@ func TestNftablesManager(t *testing.T) { require.NoError(t, err, "failed to reset") } +func TestNftablesManagerRuleOrder(t *testing.T) { + // This test verifies rule insertion order in nftables peer ACLs + // We add accept rule first, then deny rule to test ordering behavior + manager, err := Create(ifaceMock) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) + + defer func() { + err = manager.Close(nil) + require.NoError(t, err) + }() + + ip := netip.MustParseAddr("100.96.0.2").Unmap() + testClient := &nftables.Conn{} + + // Add accept rule first + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http") + require.NoError(t, err, "failed to add accept rule") + + // Add deny rule second for the same traffic + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http") + require.NoError(t, err, "failed to add deny rule") + + err = manager.Flush() + require.NoError(t, err, "failed to flush") + + rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) + require.NoError(t, err, "failed to get rules") + + t.Logf("Found %d rules in nftables chain", len(rules)) + + // Find the accept and deny rules and verify deny comes before accept + var acceptRuleIndex, denyRuleIndex int = -1, -1 + for i, rule := range rules { + hasAcceptHTTPSet := false + hasDenyHTTPSet := false + hasPort80 := false + var action string + + for _, e := range rule.Exprs { + // Check for set lookup + if lookup, ok := e.(*expr.Lookup); ok { + if lookup.SetName == "accept-http" { + hasAcceptHTTPSet = true + } else if lookup.SetName == "deny-http" { + hasDenyHTTPSet = true + } + } + // Check for port 80 + if cmp, ok := e.(*expr.Cmp); ok { + if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 { + hasPort80 = true + } + } + // Check for verdict + if verdict, ok := e.(*expr.Verdict); ok { + if verdict.Kind == expr.VerdictAccept { + action = "ACCEPT" + } else if verdict.Kind == expr.VerdictDrop { + action = "DROP" + } + } + } + + if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" { + t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i) + acceptRuleIndex = i + } else if hasDenyHTTPSet && hasPort80 && action == "DROP" { + t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i) + denyRuleIndex = i + } + } + + require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in nftables") + require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in nftables") + require.Less(t, denyRuleIndex, acceptRuleIndex, + "deny rule should come before accept rule in nftables chain (deny at index %d, accept at index %d)", + denyRuleIndex, acceptRuleIndex) +} + func TestNFtablesCreatePerformance(t *testing.T) { mock := &iFaceMock{ NameFunc: func() string { - return "lo" + return "wg-test" }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index ce04c82c7..22e6fca1f 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -18,6 +18,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { defer m.mutex.Unlock() m.outgoingRules = make(map[netip.Addr]RuleSet) + m.incomingDenyRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet) if m.udpTracker != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index f261c472f..8a56b0862 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -27,6 +27,7 @@ func (m *Manager) Close(*statemanager.Manager) error { defer m.mutex.Unlock() m.outgoingRules = make(map[netip.Addr]RuleSet) + m.incomingDenyRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet) if m.udpTracker != nil { diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index fdc026b88..7eef49e31 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -70,14 +70,13 @@ func (r RouteRules) Sort() { // Manager userspace firewall manager type Manager struct { - // outgoingRules is used for hooks only - outgoingRules map[netip.Addr]RuleSet - // incomingRules is used for filtering and hooks - incomingRules map[netip.Addr]RuleSet - routeRules RouteRules - decoders sync.Pool - wgIface common.IFaceMapper - nativeFirewall firewall.Manager + outgoingRules map[netip.Addr]RuleSet + incomingDenyRules map[netip.Addr]RuleSet + incomingRules map[netip.Addr]RuleSet + routeRules RouteRules + decoders sync.Pool + wgIface common.IFaceMapper + nativeFirewall firewall.Manager mutex sync.RWMutex @@ -186,6 +185,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe }, nativeFirewall: nativeFirewall, outgoingRules: make(map[netip.Addr]RuleSet), + incomingDenyRules: make(map[netip.Addr]RuleSet), incomingRules: make(map[netip.Addr]RuleSet), wgIface: iface, localipmanager: newLocalIPManager(), @@ -417,10 +417,17 @@ func (m *Manager) AddPeerFiltering( } m.mutex.Lock() - if _, ok := m.incomingRules[r.ip]; !ok { - m.incomingRules[r.ip] = make(RuleSet) + var targetMap map[netip.Addr]RuleSet + if r.drop { + targetMap = m.incomingDenyRules + } else { + targetMap = m.incomingRules } - m.incomingRules[r.ip][r.id] = r + + if _, ok := targetMap[r.ip]; !ok { + targetMap[r.ip] = make(RuleSet) + } + targetMap[r.ip][r.id] = r m.mutex.Unlock() return []firewall.Rule{&r}, nil } @@ -507,10 +514,24 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } - if _, ok := m.incomingRules[r.ip][r.id]; !ok { + var sourceMap map[netip.Addr]RuleSet + if r.drop { + sourceMap = m.incomingDenyRules + } else { + sourceMap = m.incomingRules + } + + if ruleset, ok := sourceMap[r.ip]; ok { + if _, exists := ruleset[r.id]; !exists { + return fmt.Errorf("delete rule: no rule with such id: %v", r.id) + } + delete(ruleset, r.id) + if len(ruleset) == 0 { + delete(sourceMap, r.ip) + } + } else { return fmt.Errorf("delete rule: no rule with such id: %v", r.id) } - delete(m.incomingRules[r.ip], r.id) return nil } @@ -572,7 +593,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } -// FilterOutBound filters outgoing packets +// FilterOutbound filters outgoing packets func (m *Manager) FilterOutbound(packetData []byte, size int) bool { return m.filterOutbound(packetData, size) } @@ -761,7 +782,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // handleLocalTraffic handles local traffic. // If it returns true, the packet should be dropped. func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { - ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) + ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData) if blocked { _, pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) @@ -971,26 +992,28 @@ func (m *Manager) isSpecialICMP(d *decoder) bool { icmpType == layers.ICMPv4TypeTimeExceeded } -func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) { +func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { m.mutex.RLock() defer m.mutex.RUnlock() + if m.isSpecialICMP(d) { return nil, false } - if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok { + if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok { return mgmtId, filter } - if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok { + if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok { + return mgmtId, filter + } + if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok { + return mgmtId, filter + } + if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok { return mgmtId, filter } - if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok { - return mgmtId, filter - } - - // Default policy: DROP ALL return nil, true } @@ -1013,6 +1036,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { payloadLayer := d.decoded[1] + for _, rule := range rules { if rule.matchByIP && ip.Compare(rule.ip) != 0 { continue @@ -1045,6 +1069,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d return rule.mgmtId, rule.drop, true } } + return nil, false, false } @@ -1116,6 +1141,7 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu m.mutex.Lock() if in { + // Incoming UDP hooks are stored in allow rules map if _, ok := m.incomingRules[r.ip]; !ok { m.incomingRules[r.ip] = make(map[string]PeerRule) } @@ -1136,6 +1162,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { m.mutex.Lock() defer m.mutex.Unlock() + // Check incoming hooks (stored in allow rules) for _, arr := range m.incomingRules { for _, r := range arr { if r.id == hookID { @@ -1144,6 +1171,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { } } } + // Check outgoing hooks for _, arr := range m.outgoingRules { for _, r := range arr { if r.id == hookID { diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index 009860f73..73f3face8 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -458,6 +458,31 @@ func TestPeerACLFiltering(t *testing.T) { ruleAction: fw.ActionDrop, shouldBeBlocked: true, }, + { + name: "Peer ACL - Drop rule should override accept all rule", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 22, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{22}}, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, + { + name: "Peer ACL - Drop all traffic from specific IP", + srcIP: "100.10.0.99", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + ruleIP: "100.10.0.99", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionDrop, + shouldBeBlocked: true, + }, } t.Run("Implicit DROP (no rules)", func(t *testing.T) { @@ -468,13 +493,11 @@ func TestPeerACLFiltering(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if tc.ruleAction == fw.ActionDrop { - // add general accept rule to test drop rule - // TODO: this only works because 0.0.0.0 is tested last, we need to implement order + // add general accept rule for the same IP to test drop rule precedence rules, err := manager.AddPeerFiltering( nil, - net.ParseIP("0.0.0.0"), + net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 3197be4e8..bac06814d 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -136,9 +136,22 @@ func TestManagerDeleteRule(t *testing.T) { return } + // Check rules exist in appropriate maps for _, r := range rule2 { - if _, ok := m.incomingRules[ip][r.ID()]; !ok { - t.Errorf("rule2 is not in the incomingRules") + peerRule, ok := r.(*PeerRule) + if !ok { + t.Errorf("rule should be a PeerRule") + continue + } + // Check if rule exists in deny or allow maps based on action + var found bool + if peerRule.drop { + _, found = m.incomingDenyRules[ip][r.ID()] + } else { + _, found = m.incomingRules[ip][r.ID()] + } + if !found { + t.Errorf("rule2 is not in the expected rules map") } } @@ -150,9 +163,22 @@ func TestManagerDeleteRule(t *testing.T) { } } + // Check rules are removed from appropriate maps for _, r := range rule2 { - if _, ok := m.incomingRules[ip][r.ID()]; ok { - t.Errorf("rule2 is not in the incomingRules") + peerRule, ok := r.(*PeerRule) + if !ok { + t.Errorf("rule should be a PeerRule") + continue + } + // Check if rule is removed from deny or allow maps based on action + var found bool + if peerRule.drop { + _, found = m.incomingDenyRules[ip][r.ID()] + } else { + _, found = m.incomingRules[ip][r.ID()] + } + if found { + t.Errorf("rule2 should be removed from the rules map") } } } @@ -196,16 +222,17 @@ func TestAddUDPPacketHook(t *testing.T) { var addedRule PeerRule if tt.in { + // Incoming UDP hooks are stored in allow rules map if len(manager.incomingRules[tt.ip]) != 1 { - t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) + t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip])) return } for _, rule := range manager.incomingRules[tt.ip] { addedRule = rule } } else { - if len(manager.outgoingRules) != 1 { - t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) + if len(manager.outgoingRules[tt.ip]) != 1 { + t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip])) return } for _, rule := range manager.outgoingRules[tt.ip] { @@ -261,8 +288,8 @@ func TestManagerReset(t *testing.T) { return } - if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { - t.Errorf("rules is not empty") + if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 { + t.Errorf("rules are not empty") } } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index ef04f2700..c75c0249d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -314,7 +314,7 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { trace.AddResult(StageRouting, "Packet destined for local delivery", true) - ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) + ruleId, blocked := m.peerACLsBlock(srcIP, d, packetData) strRuleId := "" if ruleId != nil {