[client] Cleanup firewall state on startup (#2768)

This commit is contained in:
Viktor Liu 2024-10-24 14:46:24 +02:00 committed by GitHub
parent 4e918e55ba
commit 8016710d24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 739 additions and 318 deletions

View File

@ -3,7 +3,6 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
@ -11,10 +10,11 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@ -3,7 +3,6 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"os" "os"
@ -15,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -32,54 +32,72 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
var fm firewall.Manager fm, errFw := createNativeFirewall(iface)
var errFw error
switch check() { if fm != nil {
case IPTABLES: if err := fm.Init(stateManager); err != nil {
log.Info("creating an iptables firewall manager") log.Errorf("failed to init nftables manager: %s", err)
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
} }
case NFTABLES:
log.Info("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
} }
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
var errUsp error return createUserspaceFirewall(iface, fm, errFw)
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
} }
if errFw != nil { return fm, errFw
return nil, errFw }
func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) {
switch check() {
case IPTABLES:
return createIptablesFirewall(iface)
case NFTABLES:
return createNftablesFirewall(iface)
default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, fmt.Errorf("no firewall manager found")
}
}
func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
log.Info("creating an iptables firewall manager")
fm, err := nbiptables.Create(iface)
if err != nil {
log.Errorf("failed to create iptables manager: %s", err)
}
return fm, err
}
func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
log.Info("creating an nftables firewall manager")
fm, err := nbnftables.Create(iface)
if err != nil {
log.Errorf("failed to create nftables manager: %s", err)
}
return fm, err
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
} }
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil return fm, nil
} }

View File

@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -22,6 +23,8 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type aclEntries map[string][][]string
type entry struct { type entry struct {
spec []string spec []string
position int position int
@ -32,9 +35,11 @@ type aclManager struct {
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
entries map[string][][]string entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
} }
err := ipset.Init() if err := ipset.Init(); err != nil {
if err != nil { return nil, fmt.Errorf("init ipset: %w", err)
return nil, fmt.Errorf("failed to init ipset: %w", err)
} }
return m, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries() m.seedInitialEntries()
m.seedInitialOptionalEntries() m.seedInitialOptionalEntries()
err = m.cleanChains() if err := m.cleanChains(); err != nil {
if err != nil { return fmt.Errorf("clean chains: %w", err)
return nil, err
} }
err = m.createDefaultChains() if err := m.createDefaultChains(); err != nil {
if err != nil { return fmt.Errorf("create default chains: %w", err)
return nil, err
} }
return m, nil
m.updateState()
return nil
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddPeerFiltering(
@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
chain: chain, chain: chain,
} }
m.updateState()
return []firewall.Rule{rule}, nil return []firewall.Rule{rule}, nil
} }
@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
if err != nil { return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
} }
return err
m.updateState()
return nil
} }
func (m *aclManager) Reset() error { func (m *aclManager) Reset() error {
return m.cleanChains() if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
} }
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec) m.entries[chainName] = append(m.entries[chainName], spec)
} }
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,

View File

@ -8,10 +8,13 @@ import (
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Manager of iptables firewall // Manager of iptables firewall
@ -33,10 +36,10 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("init iptables: %w", err)
} }
m := &Manager{ m := &Manager{
@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(context, iptablesClient, wgIface) m.router, err = newRouter(iptablesClient, wgIface)
if err != nil { if err != nil {
log.Debugf("failed to initialize route related chains: %s", err) return nil, fmt.Errorf("create router: %w", err)
return nil, err
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil { if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err) return nil, fmt.Errorf("create acl manager: %w", err)
return nil, err
} }
return m, nil return m, nil
} }
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall // AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
errAcl := m.aclMgr.Reset() var merr *multierror.Error
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
errMgr := m.router.Reset() if err := m.router.Reset(); err != nil {
if errMgr != nil { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
} }
return errAcl
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic

View File

@ -1,7 +1,6 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@ -3,7 +3,6 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net/netip" "net/netip"
"strconv" "strconv"
@ -18,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -48,28 +48,31 @@ type routeFilteringRuleParams struct {
SetName string SetName string
} }
type routeRules map[string][]string
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct { type router struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
rules map[string][]string rules routeRules
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] ipsetCounter *ipsetCounter
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
stateManager *statemanager.Manager
} }
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{ r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
r.createIpSet, func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error { func(name string, _ struct{}) error {
return r.deleteIpSet(name) return r.deleteIpSet(name)
}, },
@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
return nil, fmt.Errorf("init ipset: %w", err) return nil, fmt.Errorf("init ipset: %w", err)
} }
err := r.cleanUpDefaultForwardRules() return r, nil
if err != nil { }
log.Errorf("cleanup routing rules: %s", err)
return nil, err func (r *router) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
err = r.createContainers()
if err != nil { if err := r.createContainers(); err != nil {
log.Errorf("create containers for route: %s", err) return fmt.Errorf("create containers: %w", err)
} }
return r, err
r.updateState()
return nil
} }
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering(
r.rules[string(ruleKey)] = rule r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil return ruleKey, nil
} }
@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
} }
r.updateState()
return nil return nil
} }
@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string {
return "" return ""
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) return fmt.Errorf("create set %s: %w", setName, err)
} }
for _, prefix := range sources { for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil { if err := ipset.AddPrefix(setName, prefix); err != nil {
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) return fmt.Errorf("add element to set %s: %w", setName, err)
} }
} }
return struct{}{}, nil return nil
} }
func (r *router) deleteIpSet(setName string) error { func (r *router) deleteIpSet(setName string) error {
@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add inverse nat rule: %w", err) return fmt.Errorf("add inverse nat rule: %w", err)
} }
r.updateState()
return nil return nil
} }
@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy routing rule: %w", err) return fmt.Errorf("remove legacy routing rule: %w", err)
} }
r.updateState()
return nil return nil
} }
@ -280,6 +298,9 @@ func (r *router) RemoveAllLegacyRouteRules() error {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} }
} }
r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@ -294,6 +315,8 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@ -431,6 +454,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return nil return nil
} }
func (r *router) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i" intdir := "-i"
lointdir := "-o" lointdir := "-o"

View File

@ -3,7 +3,6 @@
package iptables package iptables
import ( import (
"context"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() _ = manager.Reset()
@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
err := manager.Reset() err := manager.Reset()
@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() _ = manager.Reset()
}() }()
@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client") require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(context.Background(), iptablesClient, ifaceMock) r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager") require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() { defer func() {
err := r.Reset() err := r.Reset()

View File

@ -1,14 +1,16 @@
package iptables package iptables
import "encoding/json"
type ipList struct { type ipList struct {
ips map[string]struct{} ips map[string]struct{}
} }
func newIpList(ip string) ipList { func newIpList(ip string) *ipList {
ips := make(map[string]struct{}) ips := make(map[string]struct{})
ips[ip] = struct{}{} ips[ip] = struct{}{}
return ipList{ return &ipList{
ips: ips, ips: ips,
} }
} }
@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{} s.ips[ip] = struct{}{}
} }
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
return nil
}
type ipsetStore struct { type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset ipsets map[string]*ipList
} }
func newIpsetStore() *ipsetStore { func newIpsetStore() *ipsetStore {
return &ipsetStore{ return &ipsetStore{
ipsets: make(map[string]ipList), ipsets: make(map[string]*ipList),
} }
} }
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName] r, ok := s.ipsets[ipsetName]
return r, ok return r, ok
} }
func (s *ipsetStore) addIpList(ipsetName string, list ipList) { func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list s.ipsets[ipsetName] = list
} }
func (s *ipsetStore) deleteIpset(ipsetName string) { func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName) delete(s.ipsets, ipsetName)
} }
@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
} }
return names return names
} }
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
return nil
}

View File

@ -0,0 +1,70 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@ -10,6 +10,8 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -52,6 +54,8 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
AllowNetbird() error AllowNetbird() error
@ -91,7 +95,7 @@ type Manager interface {
SetLegacyManagement(legacy bool) error SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error

View File

@ -17,7 +17,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -56,13 +55,6 @@ type AclManager struct {
rules map[string]*Rule rules map[string]*Rule
} }
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
// overloads netlink with high amount of rules ( > 10000) // overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting()) sConn, err := nftables.New(nftables.AsLasting())
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create nf conn: %w", err)
} }
m := &AclManager{ return &AclManager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
} }, nil
}
err = m.createDefaultChains() func (m *AclManager) init(workTable *nftables.Table) error {
if err != nil { m.workTable = workTable
return nil, err return m.createDefaultChains()
}
return m, nil
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall

View File

@ -14,6 +14,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -24,6 +26,13 @@ const (
chainNameInput = "INPUT" chainNameInput = "INPUT"
) )
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@ -35,30 +44,68 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
} }
workTable, err := m.createWorkTable() workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
if err != nil {
return nil, err
}
m.router, err = newRouter(context, workTable, wgIface) var err error
m.router, err = newRouter(workTable, wgIface)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create router: %w", err)
} }
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create acl manager: %w", err)
} }
return m, nil return m, nil
} }
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
@ -203,48 +250,80 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
chains, err := m.rConn.ListChains() if err := m.resetNetbirdInputRules(); err != nil {
if err != nil { return fmt.Errorf("reset netbird input rules: %v", err)
return fmt.Errorf("list of chains: %w", err)
} }
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" { if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
continue continue
} }
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { m.deleteMatchingRules(rules)
if err := m.rConn.DelRule(r); err != nil { }
log.Errorf("delete rule: %v", err) }
} }
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
} }
} }
} }
}
if err := m.router.Reset(); err != nil { func (m *Manager) cleanupNetbirdTables() error {
return fmt.Errorf("reset forward rules: %v", err)
}
tables, err := m.rConn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableNameNetbird {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return nil
return m.rConn.Flush()
} }
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer

View File

@ -1,7 +1,6 @@
package nftables package nftables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion") require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@ -2,7 +2,6 @@ package nftables
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -40,8 +39,6 @@ var (
) )
type router struct { type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn conn *nftables.Conn
workTable *nftables.Table workTable *nftables.Table
filterTable *nftables.Table filterTable *nftables.Table
@ -54,12 +51,8 @@ type router struct {
legacyManagement bool legacyManagement bool
} }
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{ r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
if errors.Is(err, errFilterTableNotFound) { if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules") log.Warnf("table 'filter' not found for forward rules")
} else { } else {
return nil, err return nil, fmt.Errorf("load filter table: %w", err)
} }
} }
err = r.removeAcceptForwardRules() return r, nil
if err != nil { }
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
err = r.createContainers() if err := r.createContainers(); err != nil {
if err != nil { return fmt.Errorf("create containers: %w", err)
log.Errorf("failed to create containers for route: %s", err)
} }
return r, err
return nil
} }
// Reset cleans existing nftables default forward rules from the system // Reset cleans existing nftables default forward rules from the system

View File

@ -3,7 +3,6 @@
package nftables package nftables
import ( import (
"context"
"encoding/binary" "encoding/binary"
"net/netip" "net/netip"
"os/exec" "os/exec"
@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock) manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router") require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock) manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router") require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) { defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules") require.NoError(t, r.Reset(), "Failed to reset rules")
@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() { defer func() {
require.NoError(t, r.Reset(), "Failed to reset router") require.NoError(t, r.Reset(), "Failed to reset router")

View File

@ -0,0 +1,47 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@ -2,8 +2,10 @@
package uspfilter package uspfilter
import "github.com/netbirdio/netbird/client/internal/statemanager"
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset() return m.nativeFirewall.Reset(stateManager)
} }
return nil return nil
} }

View File

@ -6,6 +6,8 @@ import (
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type action string type action string
@ -17,7 +19,7 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return false return false
@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil, errRouteNotSupported return nil, errRouteNotSupported
} }

View File

@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
err = m.Reset() err = m.Reset(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if err = m.Reset(); err != nil { if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@ -1,7 +1,6 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock) fw, err := firewall.NewFirewall(ifaceMock, nil)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset() _ = fw.Reset(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock) fw, err := firewall.NewFirewall(ifaceMock, nil)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset() _ = fw.Reset(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)

View File

@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
} }
// RunWithProbes runs the client's main logic with probes attached // RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes( func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
probes *ProbeHolder,
runningChan chan error,
) error {
return c.run(MobileDependency{}, probes, runningChan) return c.run(MobileDependency{}, probes, runningChan)
} }
@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil, nil)
} }
func (c *ConnectClient) run( func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
mobileDependency MobileDependency,
probes *ProbeHolder,
runningChan chan error,
) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))

View File

@ -533,6 +533,13 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()
} }

View File

@ -38,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -366,7 +367,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("create wg interface: %w", err) return fmt.Errorf("create wg interface: %w", err)
} }
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
} }
@ -1167,7 +1168,7 @@ func (e *Engine) close() {
} }
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Reset() err := e.firewall.Reset(e.stateManager)
if err != nil { if err != nil {
log.Warnf("failed to reset firewall: %s", err) log.Warnf("failed to reset firewall: %s", err)
} }

View File

@ -1,6 +1,7 @@
package refcounter package refcounter
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
} }
} }
// LoadData loads the data from the existing counter
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
}
// Get retrieves the current reference count and associated data for a key. // Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false. // If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() {
clear(rm.idMap) clear(rm.idMap)
} }
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
}{
RefCountMap: rm.refCountMap,
IDMap: rm.idMap,
})
}
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap
return nil
}
func getCallerInfo(depth int, maxDepth int) (string, bool) { func getCallerInfo(depth int, maxDepth int) (string, bool) {
if depth >= maxDepth { if depth >= maxDepth {
return "", false return "", false

View File

@ -1,30 +1,15 @@
package systemops package systemops
import ( import (
"encoding/json"
"fmt"
"net/netip" "net/netip"
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nberrors "github.com/netbirdio/netbird/client/errors"
) )
type RouteEntry struct {
Prefix netip.Prefix `json:"prefix"`
Nexthop Nexthop `json:"nexthop"`
}
type ShutdownState struct { type ShutdownState struct {
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex mu sync.RWMutex
}
func NewShutdownState() *ShutdownState {
return &ShutdownState{
Routes: make(map[netip.Prefix]RouteEntry),
}
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil) sysops := NewSysOps(nil, nil)
var merr *multierror.Error sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter)
s.mu.RLock() return sysops.refCounter.Flush()
defer s.mu.RUnlock()
for _, route := range s.Routes {
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
s.mu.Lock()
defer s.mu.Unlock()
s.Routes[prefix] = RouteEntry{
Prefix: prefix,
Nexthop: nexthop,
}
}
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.Routes, prefix)
}
// MarshalJSON ensures that empty routes are marshaled as null
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.Routes) == 0 {
return json.Marshal(nil)
}
return json.Marshal(s.Routes)
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &s.Routes)
} }

View File

@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
r.updateState(stateManager, prefix, nexthop) r.updateState(stateManager)
return nexthop, err return nexthop, err
}, },
func(prefix netip.Prefix, nexthop Nexthop) error { func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table // remove from state even if we have trouble removing it from the route table
// it could be already gone // it could be already gone
r.removeFromState(stateManager, prefix) r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop) return r.removeFromRouteTable(prefix, nexthop)
}, },
@ -75,24 +75,16 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses) return r.setupHooks(initAddresses)
} }
func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager) state := getState(stateManager)
state.UpdateRoute(prefix, nexthop)
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil { if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
} }
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
state := getState(stateManager)
state.RemoveRoute(prefix)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("Failed to update state: %v", err)
}
}
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
if r.refCounter == nil { if r.refCounter == nil {
return nil return nil
@ -107,7 +99,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
} }
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Errorf("failed to delete state: %v", err) return fmt.Errorf("delete state: %w", err)
} }
return nil return nil
@ -546,7 +538,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState {
if state := stateManager.GetState(shutdownState); state != nil { if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState) shutdownState = state.(*ShutdownState)
} else { } else {
shutdownState = NewShutdownState() shutdownState = &ShutdownState{}
} }
return shutdownState return shutdownState

View File

@ -5,7 +5,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// GetDefaultStatePath returns the path to the state file based on the operating system // GetDefaultStatePath returns the path to the state file based on the operating system
@ -27,7 +27,7 @@ func GetDefaultStatePath() string {
dir := filepath.Dir(path) dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return "" return ""
} }

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
@ -21,11 +20,7 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
@ -848,31 +843,3 @@ func sendTerminalNotification() error {
return wallCmd.Wait() return wallCmd.Wait()
} }
// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
return nil
}
mgr := statemanager.New(path)
var merr *multierror.Error
// register the states we are interested in restoring
// this will also allow each subsystem to record its own state
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
if err := mgr.PerformCleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
}
if err := mgr.PersistState(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}

37
client/server/state.go Normal file
View File

@ -0,0 +1,37 @@
package server
import (
"context"
"fmt"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
return nil
}
mgr := statemanager.New(path)
// register the states we are interested in restoring
registerStates(mgr)
var merr *multierror.Error
if err := mgr.PerformCleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
}
// persist state regardless of cleanup outcome. It could've succeeded partially
if err := mgr.PersistState(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@ -0,0 +1,14 @@
//go:build !linux || android
package server
import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
}

View File

@ -0,0 +1,18 @@
//go:build !android
package server
import (
"github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&nftables.ShutdownState{})
mgr.RegisterState(&iptables.ShutdownState{})
}