From bcc5824980c4d510e8bd777325f602dce51518e6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 4 Mar 2025 11:19:42 +0100 Subject: [PATCH] [client] Close userspace firewall properly (#3426) --- client/firewall/iptables/manager_linux.go | 2 +- client/firewall/iptables/manager_linux_test.go | 12 ++++++------ client/firewall/iptables/state_linux.go | 2 +- client/firewall/manager/firewall.go | 4 ++-- client/firewall/nftables/manager_linux.go | 4 ++-- client/firewall/nftables/manager_linux_test.go | 8 ++++---- client/firewall/nftables/router_linux_test.go | 4 ++-- client/firewall/nftables/state_linux.go | 2 +- client/firewall/uspfilter/allow_netbird.go | 8 ++------ .../firewall/uspfilter/allow_netbird_windows.go | 8 ++------ .../firewall/uspfilter/uspfilter_bench_test.go | 16 ++++++++-------- .../firewall/uspfilter/uspfilter_filter_test.go | 4 ++-- client/firewall/uspfilter/uspfilter_test.go | 12 ++++++------ client/internal/acl/manager_test.go | 4 ++-- client/internal/dns/server_test.go | 2 +- client/internal/engine.go | 2 +- 16 files changed, 43 insertions(+), 51 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 929e8a656..144d5a17f 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -166,7 +166,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset(stateManager *statemanager.Manager) error { +func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ba578c033..856633409 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) { time.Sleep(time.Second) defer func() { - err := manager.Reset(nil) + err := manager.Close(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -100,14 +100,14 @@ func TestIptablesManager(t *testing.T) { _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") - err = manager.Reset(nil) + err = manager.Close(nil) require.NoError(t, err, "failed to reset") ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) require.NoError(t, err, "failed check chain exists") if ok { - require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules) + require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules) } }) } @@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) { time.Sleep(time.Second) defer func() { - err := manager.Reset(nil) + err := manager.Close(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) { }) t.Run("reset check", func(t *testing.T) { - err = manager.Reset(nil) + err = manager.Close(nil) require.NoError(t, err, "failed to reset") }) } @@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) { time.Sleep(time.Second) defer func() { - err := manager.Reset(nil) + err := manager.Close(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 44b8340ba..2a7120bbf 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -62,7 +62,7 @@ func (s *ShutdownState) Cleanup() error { ipt.aclMgr.ipsetStore = s.ACLIPsetStore } - if err := ipt.Reset(nil); err != nil { + if err := ipt.Close(nil); err != nil { return fmt.Errorf("reset iptables manager: %w", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index d007e20a5..e71328a44 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -94,8 +94,8 @@ type Manager interface { // SetLegacyManagement sets the legacy management mode SetLegacyManagement(legacy bool) error - // Reset firewall to the default state - Reset(stateManager *statemanager.Manager) error + // Close closes the firewall manager + Close(stateManager *statemanager.Manager) error // Flush the changes to firewall controller Flush() error diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index de68f3291..3df9b378d 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { // We only need to record minimal interface state for potential recreation. // Unlike iptables, which requires tracking individual rules, nftables maintains // a known state (our netbird table plus a few static rules). This allows for easy - // cleanup using Reset() without needing to store specific rules. + // cleanup using Close() without needing to store specific rules. if err := stateManager.UpdateState(&ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -242,7 +242,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset(stateManager *statemanager.Manager) error { +func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index eaa8ef1f5..9ca20889b 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) { time.Sleep(time.Second * 3) defer func() { - err = manager.Reset(nil) + err = manager.Close(nil) require.NoError(t, err, "failed to reset") time.Sleep(time.Second) }() @@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) { // established rule remains require.Len(t, rules, 1, "expected 1 rules after deletion") - err = manager.Reset(nil) + err = manager.Close(nil) require.NoError(t, err, "failed to reset") } @@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { time.Sleep(time.Second * 3) defer func() { - if err := manager.Reset(nil); err != nil { + if err := manager.Close(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) @@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { require.NoError(t, manager.Init(nil)) t.Cleanup(func() { - err := manager.Reset(nil) + err := manager.Close(nil) require.NoError(t, err, "failed to reset manager state") // Verify iptables output after reset diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 2a5d7168d..9081a8349 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { // need fw manager to init both acl mgr and router for all chains to be present manager, err := Create(ifaceMock) t.Cleanup(func() { - require.NoError(t, manager.Reset(nil)) + require.NoError(t, manager.Close(nil)) }) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { manager, err := Create(ifaceMock) t.Cleanup(func() { - require.NoError(t, manager.Reset(nil)) + require.NoError(t, manager.Close(nil)) }) require.NoError(t, err) require.NoError(t, manager.Init(nil)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index a68c8b8b8..facca1cec 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -39,7 +39,7 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create nftables manager: %w", err) } - if err := nft.Reset(nil); err != nil { + if err := nft.Close(nil); err != nil { return fmt.Errorf("reset nftables manager: %w", err) } diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 03f23f5e6..aba79bc21 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -8,12 +8,11 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) // Reset firewall to the default state -func (m *Manager) Reset(stateManager *statemanager.Manager) error { +func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -22,17 +21,14 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } if m.forwarder != nil { @@ -48,7 +44,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { } if m.nativeFirewall != nil { - return m.nativeFirewall.Reset(stateManager) + return m.nativeFirewall.Close(stateManager) } return nil } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 379585978..ee540cb1d 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -9,7 +9,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -21,8 +20,8 @@ const ( firewallRuleName = "Netbird" ) -// Reset firewall to the default state -func (m *Manager) Reset(*statemanager.Manager) error { +// Close closes the firewall manager +func (m *Manager) Close(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -31,17 +30,14 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } if m.forwarder != nil { diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 875bb2425..bb42a8052 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -160,7 +160,7 @@ func BenchmarkCoreFiltering(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) defer b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.wgNetwork = &net.IPNet{ @@ -205,7 +205,7 @@ func BenchmarkStateScaling(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.wgNetwork = &net.IPNet{ @@ -253,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.wgNetwork = &net.IPNet{ @@ -452,7 +452,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) // Setup scenario @@ -579,7 +579,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) defer b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.SetNetwork(&net.IPNet{ @@ -670,7 +670,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) defer b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.SetNetwork(&net.IPNet{ @@ -789,7 +789,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) defer b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.SetNetwork(&net.IPNet{ @@ -877,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false) defer b.Cleanup(func() { - require.NoError(b, manager.Reset(nil)) + require.NoError(b, manager.Close(nil)) }) manager.SetNetwork(&net.IPNet{ diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 9a1456d00..9a5ec9c66 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -39,7 +39,7 @@ func TestPeerACLFiltering(t *testing.T) { require.NotNil(t, manager) t.Cleanup(func() { - require.NoError(t, manager.Reset(nil)) + require.NoError(t, manager.Close(nil)) }) manager.wgNetwork = wgNet @@ -310,7 +310,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { require.False(tb, manager.nativeRouter) tb.Cleanup(func() { - require.NoError(tb, manager.Reset(nil)) + require.NoError(tb, manager.Close(nil)) }) return manager diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 089bf8f55..c03762984 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -254,7 +254,7 @@ func TestManagerReset(t *testing.T) { return } - err = m.Reset(nil) + err = m.Close(nil) if err != nil { t.Errorf("failed to reset Manager: %v", err) return @@ -333,7 +333,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if err = m.Reset(nil); err != nil { + if err = m.Close(nil); err != nil { t.Errorf("failed to reset Manager: %v", err) return } @@ -352,7 +352,7 @@ func TestRemovePacketHook(t *testing.T) { t.Fatalf("Failed to create Manager: %s", err) } defer func() { - require.NoError(t, manager.Reset(nil)) + require.NoError(t, manager.Close(nil)) }() // Add a UDP packet hook @@ -403,7 +403,7 @@ func TestProcessOutgoingHooks(t *testing.T) { manager.udpTracker.Close() manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) defer func() { - require.NoError(t, manager.Reset(nil)) + require.NoError(t, manager.Close(nil)) }() manager.decoders = sync.Pool{ @@ -484,7 +484,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { time.Sleep(time.Second) defer func() { - if err := manager.Reset(nil); err != nil { + if err := manager.Close(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) @@ -530,7 +530,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { }, } defer func() { - require.NoError(t, manager.Reset(nil)) + require.NoError(t, manager.Close(nil)) }() // Set up packet parameters diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 217dbce9f..0327d62ef 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -58,7 +58,7 @@ func TestDefaultManager(t *testing.T) { return } defer func(fw manager.Manager) { - _ = fw.Reset(nil) + _ = fw.Close(nil) }(fw) acl := NewDefaultManager(fw) @@ -352,7 +352,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { return } defer func(fw manager.Manager) { - _ = fw.Reset(nil) + _ = fw.Close(nil) }(fw) acl := NewDefaultManager(fw) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 94b87124b..d9886fcd8 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -1015,7 +1015,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { mh.AssertExpectations(t) } - // Reset mocks + // Close mocks if mh, ok := tc.expectedHandler.(*MockHandler); ok { mh.ExpectedCalls = nil mh.Calls = nil diff --git a/client/internal/engine.go b/client/internal/engine.go index 10c4fb970..943b7cd0b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1362,7 +1362,7 @@ func (e *Engine) close() { } if e.firewall != nil { - err := e.firewall.Reset(e.stateManager) + err := e.firewall.Close(e.stateManager) if err != nil { log.Warnf("failed to reset firewall: %s", err) }