From 8016710d241efb2b8dee03ff317128d3cec198b8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:46:24 +0200 Subject: [PATCH] [client] Cleanup firewall state on startup (#2768) --- client/firewall/create.go | 4 +- client/firewall/create_linux.go | 90 +++++++----- client/firewall/iptables/acl_linux.go | 79 ++++++++--- client/firewall/iptables/manager_linux.go | 71 +++++++--- .../firewall/iptables/manager_linux_test.go | 20 +-- client/firewall/iptables/router_linux.go | 93 ++++++++++--- client/firewall/iptables/router_linux_test.go | 13 +- client/firewall/iptables/rulestore_linux.go | 57 +++++++- client/firewall/iptables/state_linux.go | 70 ++++++++++ client/firewall/manager/firewall.go | 6 +- client/firewall/nftables/acl_linux.go | 24 +--- client/firewall/nftables/manager_linux.go | 131 ++++++++++++++---- .../firewall/nftables/manager_linux_test.go | 13 +- client/firewall/nftables/router_linux.go | 28 ++-- client/firewall/nftables/router_linux_test.go | 13 +- client/firewall/nftables/state_linux.go | 47 +++++++ client/firewall/uspfilter/allow_netbird.go | 6 +- .../uspfilter/allow_netbird_windows.go | 4 +- client/firewall/uspfilter/uspfilter.go | 7 +- client/firewall/uspfilter/uspfilter_test.go | 6 +- client/internal/acl/manager_test.go | 9 +- client/internal/connect.go | 11 +- client/internal/dns/server.go | 7 + client/internal/engine.go | 5 +- .../routemanager/refcounter/refcounter.go | 40 ++++++ .../internal/routemanager/systemops/state.go | 75 ++-------- .../systemops/systemops_generic.go | 22 +-- client/internal/statemanager/path.go | 4 +- client/server/server.go | 33 ----- client/server/state.go | 37 +++++ client/server/state_generic.go | 14 ++ client/server/state_linux.go | 18 +++ 32 files changed, 739 insertions(+), 318 deletions(-) create mode 100644 client/firewall/iptables/state_linux.go create mode 100644 client/firewall/nftables/state_linux.go create mode 100644 client/server/state.go create mode 100644 client/server/state_generic.go create mode 100644 client/server/state_linux.go diff --git a/client/firewall/create.go b/client/firewall/create.go index 86ce94cea..9466f4b4d 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -3,7 +3,6 @@ package firewall import ( - "context" "fmt" "runtime" @@ -11,10 +10,11 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) // NewFirewall creates a firewall manager instance -func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 92deb63dc..c853548f8 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -3,7 +3,6 @@ package firewall import ( - "context" "fmt" "os" @@ -15,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -32,54 +32,72 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - var fm firewall.Manager - var errFw error + fm, errFw := createNativeFirewall(iface) - switch check() { - case IPTABLES: - log.Info("creating an iptables firewall manager") - fm, errFw = nbiptables.Create(context, iface) - if errFw != nil { - log.Errorf("failed to create iptables manager: %s", errFw) + if fm != nil { + if err := fm.Init(stateManager); err != nil { + log.Errorf("failed to init nftables manager: %s", err) } - case NFTABLES: - log.Info("creating an nftables firewall manager") - fm, errFw = nbnftables.Create(context, iface) - if errFw != nil { - log.Errorf("failed to create nftables manager: %s", errFw) - } - default: - errFw = fmt.Errorf("no firewall manager found") - log.Info("no firewall manager found, trying to use userspace packet filtering firewall") } if iface.IsUserspaceBind() { - var errUsp error - if errFw == nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) - } else { - fm, errUsp = uspfilter.Create(iface) - } - if errUsp != nil { - log.Debugf("failed to create userspace filtering firewall: %s", errUsp) - return nil, errUsp - } - - if err := fm.AllowNetbird(); err != nil { - log.Errorf("failed to allow netbird interface traffic: %v", err) - } - return fm, nil + return createUserspaceFirewall(iface, fm, errFw) } - if errFw != nil { - return nil, errFw + return fm, errFw +} + +func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) { + switch check() { + case IPTABLES: + return createIptablesFirewall(iface) + case NFTABLES: + return createNftablesFirewall(iface) + default: + log.Info("no firewall manager found, trying to use userspace packet filtering firewall") + return nil, fmt.Errorf("no firewall manager found") + } +} + +func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) { + log.Info("creating an iptables firewall manager") + fm, err := nbiptables.Create(iface) + if err != nil { + log.Errorf("failed to create iptables manager: %s", err) + } + return fm, err +} + +func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) { + log.Info("creating an nftables firewall manager") + fm, err := nbnftables.Create(iface) + if err != nil { + log.Errorf("failed to create nftables manager: %s", err) + } + return fm, err +} + +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) { + var errUsp error + if errFw == nil { + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + } else { + fm, errUsp = uspfilter.Create(iface) } + if errUsp != nil { + log.Debugf("failed to create userspace filtering firewall: %s", errUsp) + return nil, errUsp + } + + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } return fm, nil } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c271e592d..5cd69245b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -22,6 +23,8 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type aclEntries map[string][][]string + type entry struct { spec []string position int @@ -32,9 +35,11 @@ type aclManager struct { wgIface iFaceMapper routingFwChainName string - entries map[string][][]string + entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + + stateManager *statemanager.Manager } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi ipsetStore: newIpsetStore(), } - err := ipset.Init() - if err != nil { - return nil, fmt.Errorf("failed to init ipset: %w", err) + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) } + return m, nil +} + +func (m *aclManager) init(stateManager *statemanager.Manager) error { + m.stateManager = stateManager + m.seedInitialEntries() m.seedInitialOptionalEntries() - err = m.cleanChains() - if err != nil { - return nil, err + if err := m.cleanChains(); err != nil { + return fmt.Errorf("clean chains: %w", err) } - err = m.createDefaultChains() - if err != nil { - return nil, err + if err := m.createDefaultChains(); err != nil { + return fmt.Errorf("create default chains: %w", err) } - return m, nil + + m.updateState() + + return nil } func (m *aclManager) AddPeerFiltering( @@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering( chain: chain, } + m.updateState() + return []firewall.Rule{rule}, nil } @@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { } } - err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) - if err != nil { - log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) + if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { + return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) } - return err + + m.updateState() + + return nil } func (m *aclManager) Reset() error { - return m.cleanChains() + if err := m.cleanChains(); err != nil { + return fmt.Errorf("clean chains: %w", err) + } + + m.updateState() + + return nil } // todo write less destructive cleanup mechanism @@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } +func (m *aclManager) updateState() { + if m.stateManager == nil { + return + } + + var currentState *ShutdownState + if existing := m.stateManager.GetState(currentState); existing != nil { + if existingState, ok := existing.(*ShutdownState); ok { + currentState = existingState + } + } + if currentState == nil { + currentState = &ShutdownState{} + } + + currentState.Lock() + defer currentState.Unlock() + + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + + if err := m.stateManager.UpdateState(currentState); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs( ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 94bd2fccf..a59bd2c60 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -8,10 +8,13 @@ import ( "sync" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/statemanager" ) // Manager of iptables firewall @@ -33,10 +36,10 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { - return nil, fmt.Errorf("iptables is not installed in the system or not supported") + return nil, fmt.Errorf("init iptables: %w", err) } m := &Manager{ @@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(context, iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface) if err != nil { - log.Debugf("failed to initialize route related chains: %s", err) - return nil, err + return nil, fmt.Errorf("create router: %w", err) } + m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) if err != nil { - log.Debugf("failed to initialize ACL manager: %s", err) - return nil, err + return nil, fmt.Errorf("create acl manager: %w", err) } return m, nil } +func (m *Manager) Init(stateManager *statemanager.Manager) error { + state := &ShutdownState{ + InterfaceState: &InterfaceState{ + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + UserspaceBind: m.wgIface.IsUserspaceBind(), + }, + } + stateManager.RegisterState(state) + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } + + if err := m.router.init(stateManager); err != nil { + return fmt.Errorf("router init: %w", err) + } + + if err := m.aclMgr.init(stateManager); err != nil { + // TODO: cleanup router + return fmt.Errorf("acl manager init: %w", err) + } + + // persist early to ensure cleanup of chains + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - errAcl := m.aclMgr.Reset() - if errAcl != nil { - log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) + var merr *multierror.Error + + if err := m.aclMgr.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } - errMgr := m.router.Reset() - if errMgr != nil { - log.Errorf("failed to clean up router rules from firewall: %s", errMgr) - return errMgr + if err := m.router.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } - return errAcl + + // attempt to delete state only if all other operations succeeded + if merr == nil { + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 498d8f58b..ebdb83137 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -1,7 +1,6 @@ package iptables import ( - "context" "fmt" "net" "testing" @@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(context.Background(), ifaceMock) + manager, err := Create(ifaceMock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) { _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) @@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) { }) t.Run("reset check", func(t *testing.T) { - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") }) } @@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 129323928..90811ae11 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -3,7 +3,6 @@ package iptables import ( - "context" "fmt" "net/netip" "strconv" @@ -18,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -48,28 +48,31 @@ type routeFilteringRuleParams struct { SetName string } +type routeRules map[string][]string + +type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] + type router struct { - ctx context.Context - stop context.CancelFunc iptablesClient *iptables.IPTables - rules map[string][]string - ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] + rules routeRules + ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + + stateManager *statemanager.Manager } -func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { r := &router{ - ctx: ctx, - stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, } r.ipsetCounter = refcounter.New( - r.createIpSet, + func(name string, sources []netip.Prefix) (struct{}, error) { + return struct{}{}, r.createIpSet(name, sources) + }, func(name string, _ struct{}) error { return r.deleteIpSet(name) }, @@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI return nil, fmt.Errorf("init ipset: %w", err) } - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("cleanup routing rules: %s", err) - return nil, err + return r, nil +} + +func (r *router) init(stateManager *statemanager.Manager) error { + r.stateManager = stateManager + + if err := r.cleanUpDefaultForwardRules(); err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.createContainers() - if err != nil { - log.Errorf("create containers for route: %s", err) + + if err := r.createContainers(); err != nil { + return fmt.Errorf("create containers: %w", err) } - return r, err + + r.updateState() + + return nil } func (r *router) AddRouteFiltering( @@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering( r.rules[string(ruleKey)] = rule + r.updateState() + return ruleKey, nil } @@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { log.Debugf("route rule %s not found", ruleKey) } + r.updateState() + return nil } @@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string { return "" } -func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { +func (r *router) createIpSet(setName string, sources []netip.Prefix) error { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { - return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) + return fmt.Errorf("create set %s: %w", setName, err) } for _, prefix := range sources { if err := ipset.AddPrefix(setName, prefix); err != nil { - return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + return fmt.Errorf("add element to set %s: %w", setName, err) } } - return struct{}{}, nil + return nil } func (r *router) deleteIpSet(setName string) error { @@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { return fmt.Errorf("add inverse nat rule: %w", err) } + r.updateState() + return nil } @@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy routing rule: %w", err) } + r.updateState() + return nil } @@ -280,6 +298,9 @@ func (r *router) RemoveAllLegacyRouteRules() error { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } } + + r.updateState() + return nberrors.FormatErrorOrNil(merr) } @@ -294,6 +315,8 @@ func (r *router) Reset() error { merr = multierror.Append(merr, err) } + r.updateState() + return nberrors.FormatErrorOrNil(merr) } @@ -431,6 +454,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { return nil } +func (r *router) updateState() { + if r.stateManager == nil { + return + } + + var currentState *ShutdownState + if existing := r.stateManager.GetState(currentState); existing != nil { + if existingState, ok := existing.(*ShutdownState); ok { + currentState = existingState + } + } + if currentState == nil { + currentState = &ShutdownState{} + } + + currentState.Lock() + defer currentState.Unlock() + + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + + if err := r.stateManager.UpdateState(currentState); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { intdir := "-i" lointdir := "-o" diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 6cede09e2..2d821a9db 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -3,7 +3,6 @@ package iptables import ( - "context" "net/netip" "os/exec" "testing" @@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") + require.NoError(t, manager.init(nil)) defer func() { _ = manager.Reset() @@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") + require.NoError(t, manager.init(nil)) defer func() { err := manager.Reset() @@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") + require.NoError(t, manager.init(nil)) defer func() { _ = manager.Reset() }() @@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "Failed to create router manager") + require.NoError(t, r.init(nil)) defer func() { err := r.Reset() diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index a9470c9ac..bfd08bee2 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -1,14 +1,16 @@ package iptables +import "encoding/json" + type ipList struct { ips map[string]struct{} } -func newIpList(ip string) ipList { +func newIpList(ip string) *ipList { ips := make(map[string]struct{}) ips[ip] = struct{}{} - return ipList{ + return &ipList{ ips: ips, } } @@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) { s.ips[ip] = struct{}{} } +// MarshalJSON implements json.Marshaler +func (s *ipList) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + IPs map[string]struct{} `json:"ips"` + }{ + IPs: s.ips, + }) +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *ipList) UnmarshalJSON(data []byte) error { + temp := struct { + IPs map[string]struct{} `json:"ips"` + }{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + s.ips = temp.IPs + return nil +} + type ipsetStore struct { - ipsets map[string]ipList // ipsetName -> ruleset + ipsets map[string]*ipList } func newIpsetStore() *ipsetStore { return &ipsetStore{ - ipsets: make(map[string]ipList), + ipsets: make(map[string]*ipList), } } -func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { +func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) { r, ok := s.ipsets[ipsetName] return r, ok } -func (s *ipsetStore) addIpList(ipsetName string, list ipList) { +func (s *ipsetStore) addIpList(ipsetName string, list *ipList) { s.ipsets[ipsetName] = list } func (s *ipsetStore) deleteIpset(ipsetName string) { - s.ipsets[ipsetName] = ipList{} delete(s.ipsets, ipsetName) } @@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string { } return names } + +// MarshalJSON implements json.Marshaler +func (s *ipsetStore) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + IPSets map[string]*ipList `json:"ipsets"` + }{ + IPSets: s.ipsets, + }) +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *ipsetStore) UnmarshalJSON(data []byte) error { + temp := struct { + IPSets map[string]*ipList `json:"ipsets"` + }{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + s.ipsets = temp.IPSets + return nil +} diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go new file mode 100644 index 000000000..44b8340ba --- /dev/null +++ b/client/firewall/iptables/state_linux.go @@ -0,0 +1,70 @@ +package iptables + +import ( + "fmt" + "sync" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +type InterfaceState struct { + NameStr string `json:"name"` + WGAddress iface.WGAddress `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` +} + +func (i *InterfaceState) Name() string { + return i.NameStr +} + +func (i *InterfaceState) Address() device.WGAddress { + return i.WGAddress +} + +func (i *InterfaceState) IsUserspaceBind() bool { + return i.UserspaceBind +} + +type ShutdownState struct { + sync.Mutex + + InterfaceState *InterfaceState `json:"interface_state,omitempty"` + + RouteRules routeRules `json:"route_rules,omitempty"` + RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"` + + ACLEntries aclEntries `json:"acl_entries,omitempty"` + ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` +} + +func (s *ShutdownState) Name() string { + return "iptables_state" +} + +func (s *ShutdownState) Cleanup() error { + ipt, err := Create(s.InterfaceState) + if err != nil { + return fmt.Errorf("create iptables manager: %w", err) + } + + if s.RouteRules != nil { + ipt.router.rules = s.RouteRules + } + if s.RouteIPsetCounter != nil { + ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter) + } + + if s.ACLEntries != nil { + ipt.aclMgr.entries = s.ACLEntries + } + if s.ACLIPsetStore != nil { + ipt.aclMgr.ipsetStore = s.ACLIPsetStore + } + + if err := ipt.Reset(nil); err != nil { + return fmt.Errorf("reset iptables manager: %w", err) + } + + return nil +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 556bda0d6..2a40cd9f6 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -10,6 +10,8 @@ import ( "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -52,6 +54,8 @@ const ( // It declares methods which handle actions required by the // Netbird client for ACL and routing functionality type Manager interface { + Init(stateManager *statemanager.Manager) error + // AllowNetbird allows netbird interface traffic AllowNetbird() error @@ -91,7 +95,7 @@ type Manager interface { SetLegacyManagement(legacy bool) error // Reset firewall to the default state - Reset() error + Reset(stateManager *statemanager.Manager) error // Flush the changes to firewall controller Flush() error diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 61434f035..ca7b2e59f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -17,7 +17,6 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -56,13 +55,6 @@ type AclManager struct { rules map[string]*Rule } -// iFaceMapper defines subset methods of interface required for manager -type iFaceMapper interface { - Name() string - Address() iface.WGAddress - IsUserspaceBind() bool -} - func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) @@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { - return nil, err + return nil, fmt.Errorf("create nf conn: %w", err) } - m := &AclManager{ + return &AclManager{ rConn: &nftables.Conn{}, sConn: sConn, wgIface: wgIface, @@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), - } + }, nil +} - err = m.createDefaultChains() - if err != nil { - return nil, err - } - - return m, nil +func (m *AclManager) init(workTable *nftables.Table) error { + m.workTable = workTable + return m.createDefaultChains() } // AddPeerFiltering rule to the firewall diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 01b08bd71..a4650f3b6 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -24,6 +26,13 @@ const ( chainNameInput = "INPUT" ) +// iFaceMapper defines subset methods of interface required for manager +type iFaceMapper interface { + Name() string + Address() iface.WGAddress + IsUserspaceBind() bool +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -35,30 +44,68 @@ type Manager struct { } // Create nftables firewall manager -func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, } - workTable, err := m.createWorkTable() - if err != nil { - return nil, err - } + workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} - m.router, err = newRouter(context, workTable, wgIface) + var err error + m.router, err = newRouter(workTable, wgIface) if err != nil { - return nil, err + return nil, fmt.Errorf("create router: %w", err) } m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { - return nil, err + return nil, fmt.Errorf("create acl manager: %w", err) } return m, nil } +// Init nftables firewall manager +func (m *Manager) Init(stateManager *statemanager.Manager) error { + workTable, err := m.createWorkTable() + if err != nil { + return fmt.Errorf("create work table: %w", err) + } + + if err := m.router.init(workTable); err != nil { + return fmt.Errorf("router init: %w", err) + } + + if err := m.aclManager.init(workTable); err != nil { + // TODO: cleanup router + return fmt.Errorf("acl manager init: %w", err) + } + + stateManager.RegisterState(&ShutdownState{}) + + // We only need to record minimal interface state for potential recreation. + // Unlike iptables, which requires tracking individual rules, nftables maintains + // a known state (our netbird table plus a few static rules). This allows for easy + // cleanup using Reset() without needing to store specific rules. + if err := stateManager.UpdateState(&ShutdownState{ + InterfaceState: &InterfaceState{ + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + UserspaceBind: m.wgIface.IsUserspaceBind(), + }, + }); err != nil { + log.Errorf("failed to update state: %v", err) + } + + // persist early + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + + return nil +} + // AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set @@ -203,48 +250,80 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list of chains: %w", err) + if err := m.resetNetbirdInputRules(); err != nil { + return fmt.Errorf("reset netbird input rules: %v", err) } + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset router: %v", err) + } + + if err := m.cleanupNetbirdTables(); err != nil { + return fmt.Errorf("cleanup netbird tables: %v", err) + } + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + return fmt.Errorf("delete state: %v", err) + } + + return nil +} + +func (m *Manager) resetNetbirdInputRules() error { + chains, err := m.rConn.ListChains() + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + + m.deleteNetbirdInputRules(chains) + + return nil +} + +func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - // delete Netbird allow input traffic rule if it exists if c.Table.Name == "filter" && c.Name == "INPUT" { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) continue } - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } + + m.deleteMatchingRules(rules) + } + } +} + +func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { + for _, r := range rules { + if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { + if err := m.rConn.DelRule(r); err != nil { + log.Errorf("delete rule: %v", err) } } } +} - if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset forward rules: %v", err) - } - +func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { - return fmt.Errorf("list of tables: %w", err) + return fmt.Errorf("list tables: %w", err) } + for _, t := range tables { if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } - - return m.rConn.Flush() + return nil } // Flush rule/chain/set operations from the buffer diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index bbe18ab07..77f4f0306 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,7 +1,6 @@ package nftables import ( - "context" "fmt" "net" "net/netip" @@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), ifaceMock) + manager, err := Create(ifaceMock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) defer func() { - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") time.Sleep(time.Second) }() @@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) { // established rule remains require.Len(t, rules, 1, "expected 1 rules after deletion") - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") } @@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) defer func() { - if err := manager.Reset(); err != nil { + if err := manager.Reset(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 03526fee7..9b28e4eb2 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -2,7 +2,6 @@ package nftables import ( "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -40,8 +39,6 @@ var ( ) type router struct { - ctx context.Context - stop context.CancelFunc conn *nftables.Conn workTable *nftables.Table filterTable *nftables.Table @@ -54,12 +51,8 @@ type router struct { legacyManagement bool } -func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) - +func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { r := &router{ - ctx: ctx, - stop: cancel, conn: &nftables.Conn{}, workTable: workTable, chains: make(map[string]*nftables.Chain), @@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa if errors.Is(err, errFilterTableNotFound) { log.Warnf("table 'filter' not found for forward rules") } else { - return nil, err + return nil, fmt.Errorf("load filter table: %w", err) } } - err = r.removeAcceptForwardRules() - if err != nil { + return r, nil +} + +func (r *router) init(workTable *nftables.Table) error { + r.workTable = workTable + + if err := r.removeAcceptForwardRules(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) + if err := r.createContainers(); err != nil { + return fmt.Errorf("create containers: %w", err) } - return r, err + + return nil } // Reset cleans existing nftables default forward rules from the system diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index c07111b4e..19ed48991 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -3,7 +3,6 @@ package nftables import ( - "context" "encoding/binary" "net/netip" "os/exec" @@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table, ifaceMock) + manager, err := newRouter(table, ifaceMock) require.NoError(t, err, "failed to create router") + require.NoError(t, manager.init(table)) nftablesTestingClient := &nftables.Conn{} @@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table, ifaceMock) + manager, err := newRouter(table, ifaceMock) require.NoError(t, err, "failed to create router") + require.NoError(t, manager.init(table)) nftablesTestingClient := &nftables.Conn{} @@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(context.Background(), workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock) require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) defer func(r *router) { require.NoError(t, r.Reset(), "Failed to reset rules") @@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(context.Background(), workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock) require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) defer func() { require.NoError(t, r.Reset(), "Failed to reset router") diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go new file mode 100644 index 000000000..a68c8b8b8 --- /dev/null +++ b/client/firewall/nftables/state_linux.go @@ -0,0 +1,47 @@ +package nftables + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +type InterfaceState struct { + NameStr string `json:"name"` + WGAddress iface.WGAddress `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` +} + +func (i *InterfaceState) Name() string { + return i.NameStr +} + +func (i *InterfaceState) Address() device.WGAddress { + return i.WGAddress +} + +func (i *InterfaceState) IsUserspaceBind() bool { + return i.UserspaceBind +} + +type ShutdownState struct { + InterfaceState *InterfaceState `json:"interface_state,omitempty"` +} + +func (s *ShutdownState) Name() string { + return "nftables_state" +} + +func (s *ShutdownState) Cleanup() error { + nft, err := Create(s.InterfaceState) + if err != nil { + return fmt.Errorf("create nftables manager: %w", err) + } + + if err := nft.Reset(nil); err != nil { + return fmt.Errorf("reset nftables manager: %w", err) + } + + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 2275dad39..cefc81a3c 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,8 +2,10 @@ package uspfilter +import "github.com/netbirdio/netbird/client/internal/statemanager" + // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -11,7 +13,7 @@ func (m *Manager) Reset() error { m.incomingRules = make(map[string]RuleSet) if m.nativeFirewall != nil { - return m.nativeFirewall.Reset() + return m.nativeFirewall.Reset(stateManager) } return nil } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 34274564f..d3732301e 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -6,6 +6,8 @@ import ( "syscall" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) type action string @@ -17,7 +19,7 @@ const ( ) // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 0e3ee9799..3829a9baf 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const layerTypeAll = 0 @@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) { return m, nil } +func (m *Manager) Init(*statemanager.Manager) error { + return nil +} + func (m *Manager) IsServerRouteSupported() bool { if m.nativeFirewall == nil { return false @@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering( return []firewall.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errRouteNotSupported } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index c188deea4..d7c93cb7f 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) { return } - err = m.Reset() + err = m.Reset(nil) if err != nil { t.Errorf("failed to reset Manager: %v", err) return @@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if err = m.Reset(); err != nil { + if err = m.Reset(nil); err != nil { t.Errorf("failed to reset Manager: %v", err) return } @@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { time.Sleep(time.Second) defer func() { - if err := manager.Reset(); err != nil { + if err := manager.Reset(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 7d999669a..9a766021a 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,7 +1,6 @@ package acl import ( - "context" "net" "testing" @@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) @@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) diff --git a/client/internal/connect.go b/client/internal/connect.go index 13f10fbf1..bcc9d17a3 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error { } // RunWithProbes runs the client's main logic with probes attached -func (c *ConnectClient) RunWithProbes( - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error { return c.run(MobileDependency{}, probes, runningChan) } @@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS( return c.run(mobileDependency, nil, nil) } -func (c *ConnectClient) run( - mobileDependency MobileDependency, - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error { defer func() { if r := recover(); r != nil { log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 772797fac..929e1e60c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -533,6 +533,13 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + if err := s.stateManager.PersistState(ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() } diff --git a/client/internal/engine.go b/client/internal/engine.go index af2817e6e..190d795cd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -366,7 +367,7 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) } @@ -1167,7 +1168,7 @@ func (e *Engine) close() { } if e.firewall != nil { - err := e.firewall.Reset() + err := e.firewall.Reset(e.stateManager) if err != nil { log.Warnf("failed to reset firewall: %s", err) } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 65ea0f708..c121b7d77 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -1,6 +1,7 @@ package refcounter import ( + "encoding/json" "errors" "fmt" "runtime" @@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } } +// LoadData loads the data from the existing counter +func (rm *Counter[Key, I, O]) LoadData( + existingCounter *Counter[Key, I, O], +) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + rm.refCountMap = existingCounter.refCountMap + rm.idMap = existingCounter.idMap +} + // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { @@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() { clear(rm.idMap) } +// MarshalJSON implements the json.Marshaler interface for Counter. +func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + }{ + RefCountMap: rm.refCountMap, + IDMap: rm.idMap, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Counter. +func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + var temp struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + rm.refCountMap = temp.RefCountMap + rm.idMap = temp.IDMap + + return nil +} + func getCallerInfo(depth int, maxDepth int) (string, bool) { if depth >= maxDepth { return "", false diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 269924677..425908922 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -1,30 +1,15 @@ package systemops import ( - "encoding/json" - "fmt" "net/netip" "sync" - "github.com/hashicorp/go-multierror" - - nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type RouteEntry struct { - Prefix netip.Prefix `json:"prefix"` - Nexthop Nexthop `json:"nexthop"` -} - type ShutdownState struct { - Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` - mu sync.RWMutex -} - -func NewShutdownState() *ShutdownState { - return &ShutdownState{ - Routes: make(map[netip.Prefix]RouteEntry), - } + Counter *ExclusionCounter `json:"counter,omitempty"` + mu sync.RWMutex } func (s *ShutdownState) Name() string { @@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.Counter == nil { + return nil + } + sysops := NewSysOps(nil, nil) - var merr *multierror.Error + sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) + sysops.refCounter.LoadData(s.Counter) - s.mu.RLock() - defer s.mu.RUnlock() - - for _, route := range s.Routes { - if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err)) - } - } - - return nberrors.FormatErrorOrNil(merr) -} - -func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) { - s.mu.Lock() - defer s.mu.Unlock() - - s.Routes[prefix] = RouteEntry{ - Prefix: prefix, - Nexthop: nexthop, - } -} - -func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) { - s.mu.Lock() - defer s.mu.Unlock() - - delete(s.Routes, prefix) -} - -// MarshalJSON ensures that empty routes are marshaled as null -func (s *ShutdownState) MarshalJSON() ([]byte, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if len(s.Routes) == 0 { - return json.Marshal(nil) - } - - return json.Marshal(s.Routes) -} - -func (s *ShutdownState) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &s.Routes) + return sysops.refCounter.Flush() } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 2b8a14ea2..4ff34aa51 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager, prefix, nexthop) + r.updateState(stateManager) return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { // remove from state even if we have trouble removing it from the route table // it could be already gone - r.removeFromState(stateManager, prefix) + r.updateState(stateManager) return r.removeFromRouteTable(prefix, nexthop) }, @@ -75,24 +75,16 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } -func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { +func (r *SysOps) updateState(stateManager *statemanager.Manager) { state := getState(stateManager) - state.UpdateRoute(prefix, nexthop) + + state.Counter = r.refCounter if err := stateManager.UpdateState(state); err != nil { log.Errorf("failed to update state: %v", err) } } -func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) { - state := getState(stateManager) - state.RemoveRoute(prefix) - - if err := stateManager.UpdateState(state); err != nil { - log.Errorf("Failed to update state: %v", err) - } -} - func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { if r.refCounter == nil { return nil @@ -107,7 +99,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - log.Errorf("failed to delete state: %v", err) + return fmt.Errorf("delete state: %w", err) } return nil @@ -546,7 +538,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState { if state := stateManager.GetState(shutdownState); state != nil { shutdownState = state.(*ShutdownState) } else { - shutdownState = NewShutdownState() + shutdownState = &ShutdownState{} } return shutdownState diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 64c5316d8..96d6a9f12 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -5,7 +5,7 @@ import ( "path/filepath" "runtime" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system @@ -27,7 +27,7 @@ func GetDefaultStatePath() string { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0755); err != nil { - logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) + log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) return "" } diff --git a/client/server/server.go b/client/server/server.go index 342f61b88..a03322081 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,7 +11,6 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -21,11 +20,7 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" - "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -848,31 +843,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required. -// Otherwise, we might not be able to connect to the management server to retrieve new config. -func restoreResidualState(ctx context.Context) error { - path := statemanager.GetDefaultStatePath() - if path == "" { - return nil - } - - mgr := statemanager.New(path) - - var merr *multierror.Error - - // register the states we are interested in restoring - // this will also allow each subsystem to record its own state - mgr.RegisterState(&dns.ShutdownState{}) - mgr.RegisterState(&systemops.ShutdownState{}) - - if err := mgr.PerformCleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) - } - - if err := mgr.PersistState(ctx); err != nil { - merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) - } - - return nberrors.FormatErrorOrNil(merr) -} diff --git a/client/server/state.go b/client/server/state.go new file mode 100644 index 000000000..509782e86 --- /dev/null +++ b/client/server/state.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. +func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + // register the states we are interested in restoring + registerStates(mgr) + + var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + // persist state regardless of cleanup outcome. It could've succeeded partially + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/server/state_generic.go b/client/server/state_generic.go new file mode 100644 index 000000000..e6c7bdd44 --- /dev/null +++ b/client/server/state_generic.go @@ -0,0 +1,14 @@ +//go:build !linux || android + +package server + +import ( + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) +} diff --git a/client/server/state_linux.go b/client/server/state_linux.go new file mode 100644 index 000000000..087628907 --- /dev/null +++ b/client/server/state_linux.go @@ -0,0 +1,18 @@ +//go:build !android + +package server + +import ( + "github.com/netbirdio/netbird/client/firewall/iptables" + "github.com/netbirdio/netbird/client/firewall/nftables" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&nftables.ShutdownState{}) + mgr.RegisterState(&iptables.ShutdownState{}) +}