mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 10:20:09 +01:00
[client] Cleanup firewall state on startup (#2768)
This commit is contained in:
parent
4e918e55ba
commit
8016710d24
@ -3,7 +3,6 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
@ -11,10 +10,11 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@ -15,6 +14,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -32,54 +32,72 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
var fm firewall.Manager
|
fm, errFw := createNativeFirewall(iface)
|
||||||
var errFw error
|
|
||||||
|
|
||||||
switch check() {
|
if fm != nil {
|
||||||
case IPTABLES:
|
if err := fm.Init(stateManager); err != nil {
|
||||||
log.Info("creating an iptables firewall manager")
|
log.Errorf("failed to init nftables manager: %s", err)
|
||||||
fm, errFw = nbiptables.Create(context, iface)
|
|
||||||
if errFw != nil {
|
|
||||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
|
||||||
}
|
}
|
||||||
case NFTABLES:
|
|
||||||
log.Info("creating an nftables firewall manager")
|
|
||||||
fm, errFw = nbnftables.Create(context, iface)
|
|
||||||
if errFw != nil {
|
|
||||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
errFw = fmt.Errorf("no firewall manager found")
|
|
||||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if iface.IsUserspaceBind() {
|
if iface.IsUserspaceBind() {
|
||||||
var errUsp error
|
return createUserspaceFirewall(iface, fm, errFw)
|
||||||
if errFw == nil {
|
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
|
||||||
} else {
|
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
|
||||||
}
|
|
||||||
if errUsp != nil {
|
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
|
||||||
return nil, errUsp
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := fm.AllowNetbird(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
return fm, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if errFw != nil {
|
return fm, errFw
|
||||||
return nil, errFw
|
}
|
||||||
|
|
||||||
|
func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
||||||
|
switch check() {
|
||||||
|
case IPTABLES:
|
||||||
|
return createIptablesFirewall(iface)
|
||||||
|
case NFTABLES:
|
||||||
|
return createNftablesFirewall(iface)
|
||||||
|
default:
|
||||||
|
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||||
|
return nil, fmt.Errorf("no firewall manager found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
||||||
|
log.Info("creating an iptables firewall manager")
|
||||||
|
fm, err := nbiptables.Create(iface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create iptables manager: %s", err)
|
||||||
|
}
|
||||||
|
return fm, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) {
|
||||||
|
log.Info("creating an nftables firewall manager")
|
||||||
|
fm, err := nbnftables.Create(iface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create nftables manager: %s", err)
|
||||||
|
}
|
||||||
|
return fm, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) {
|
||||||
|
var errUsp error
|
||||||
|
if errFw == nil {
|
||||||
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||||
|
} else {
|
||||||
|
fm, errUsp = uspfilter.Create(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if errUsp != nil {
|
||||||
|
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
|
||||||
|
return nil, errUsp
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,6 +23,8 @@ const (
|
|||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type aclEntries map[string][][]string
|
||||||
|
|
||||||
type entry struct {
|
type entry struct {
|
||||||
spec []string
|
spec []string
|
||||||
position int
|
position int
|
||||||
@ -32,9 +35,11 @@ type aclManager struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
entries map[string][][]string
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
|
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||||
@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
|
|||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ipset.Init()
|
if err := ipset.Init(); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
return nil, fmt.Errorf("failed to init ipset: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||||
|
m.stateManager = stateManager
|
||||||
|
|
||||||
m.seedInitialEntries()
|
m.seedInitialEntries()
|
||||||
m.seedInitialOptionalEntries()
|
m.seedInitialOptionalEntries()
|
||||||
|
|
||||||
err = m.cleanChains()
|
if err := m.cleanChains(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("clean chains: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.createDefaultChains()
|
if err := m.createDefaultChains(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("create default chains: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
return m, nil
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
chain: chain,
|
chain: chain,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
return []firewall.Rule{rule}, nil
|
return []firewall.Rule{rule}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
|
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||||
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
|
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) Reset() error {
|
func (m *aclManager) Reset() error {
|
||||||
return m.cleanChains()
|
if err := m.cleanChains(); err != nil {
|
||||||
|
return fmt.Errorf("clean chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
// todo write less destructive cleanup mechanism
|
||||||
@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
|||||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) updateState() {
|
||||||
|
if m.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentState *ShutdownState
|
||||||
|
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||||
|
if existingState, ok := existing.(*ShutdownState); ok {
|
||||||
|
currentState = existingState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentState == nil {
|
||||||
|
currentState = &ShutdownState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState.Lock()
|
||||||
|
defer currentState.Unlock()
|
||||||
|
|
||||||
|
currentState.ACLEntries = m.entries
|
||||||
|
currentState.ACLIPsetStore = m.ipsetStore
|
||||||
|
|
||||||
|
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(
|
func filterRuleSpecs(
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
||||||
|
@ -8,10 +8,13 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
@ -33,10 +36,10 @@ type iFaceMapper interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("init iptables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
|||||||
ipv4Client: iptablesClient,
|
ipv4Client: iptablesClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router, err = newRouter(context, iptablesClient, wgIface)
|
m.router, err = newRouter(iptablesClient, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to initialize route related chains: %s", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to initialize ACL manager: %s", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
|
state := &ShutdownState{
|
||||||
|
InterfaceState: &InterfaceState{
|
||||||
|
NameStr: m.wgIface.Name(),
|
||||||
|
WGAddress: m.wgIface.Address(),
|
||||||
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stateManager.RegisterState(state)
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.init(stateManager); err != nil {
|
||||||
|
return fmt.Errorf("router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclMgr.init(stateManager); err != nil {
|
||||||
|
// TODO: cleanup router
|
||||||
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist early to ensure cleanup of chains
|
||||||
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering adds a rule to the firewall
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
errAcl := m.aclMgr.Reset()
|
var merr *multierror.Error
|
||||||
if errAcl != nil {
|
|
||||||
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
}
|
}
|
||||||
errMgr := m.router.Reset()
|
if err := m.router.Reset(); err != nil {
|
||||||
if errMgr != nil {
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
|
|
||||||
return errMgr
|
|
||||||
}
|
}
|
||||||
return errAcl
|
|
||||||
|
// attempt to delete state only if all other operations succeeded
|
||||||
|
if merr == nil {
|
||||||
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||||
@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -18,6 +17,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -48,28 +48,31 @@ type routeFilteringRuleParams struct {
|
|||||||
SetName string
|
SetName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type routeRules map[string][]string
|
||||||
|
|
||||||
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
rules map[string][]string
|
rules routeRules
|
||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
|
ipsetCounter *ipsetCounter
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
r := &router{
|
r := &router{
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
r.createIpSet,
|
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||||
|
return struct{}{}, r.createIpSet(name, sources)
|
||||||
|
},
|
||||||
func(name string, _ struct{}) error {
|
func(name string, _ struct{}) error {
|
||||||
return r.deleteIpSet(name)
|
return r.deleteIpSet(name)
|
||||||
},
|
},
|
||||||
@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
|
|||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.cleanUpDefaultForwardRules()
|
return r, nil
|
||||||
if err != nil {
|
}
|
||||||
log.Errorf("cleanup routing rules: %s", err)
|
|
||||||
return nil, err
|
func (r *router) init(stateManager *statemanager.Manager) error {
|
||||||
|
r.stateManager = stateManager
|
||||||
|
|
||||||
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
err = r.createContainers()
|
|
||||||
if err != nil {
|
if err := r.createContainers(); err != nil {
|
||||||
log.Errorf("create containers for route: %s", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
return r, err
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering(
|
|||||||
|
|
||||||
r.rules[string(ruleKey)] = rule
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,18 +178,18 @@ func (r *router) findSetNameInRule(rule []string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||||
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
|
return fmt.Errorf("create set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, prefix := range sources {
|
for _, prefix := range sources {
|
||||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||||
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
|
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return struct{}{}, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string) error {
|
func (r *router) deleteIpSet(setName string) error {
|
||||||
@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,6 +298,9 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,6 +315,8 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -431,6 +454,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) updateState() {
|
||||||
|
if r.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentState *ShutdownState
|
||||||
|
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||||
|
if existingState, ok := existing.(*ShutdownState); ok {
|
||||||
|
currentState = existingState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentState == nil {
|
||||||
|
currentState = &ShutdownState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState.Lock()
|
||||||
|
defer currentState.Unlock()
|
||||||
|
|
||||||
|
currentState.RouteRules = r.rules
|
||||||
|
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||||
|
|
||||||
|
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||||
intdir := "-i"
|
intdir := "-i"
|
||||||
lointdir := "-o"
|
lointdir := "-o"
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
_ = manager.Reset()
|
||||||
@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset()
|
err := manager.Reset()
|
||||||
@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
require.NoError(t, manager.init(nil))
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = manager.Reset()
|
_ = manager.Reset()
|
||||||
}()
|
}()
|
||||||
@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "Failed to create iptables client")
|
require.NoError(t, err, "Failed to create iptables client")
|
||||||
|
|
||||||
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
|
r, err := newRouter(iptablesClient, ifaceMock)
|
||||||
require.NoError(t, err, "Failed to create router manager")
|
require.NoError(t, err, "Failed to create router manager")
|
||||||
|
require.NoError(t, r.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := r.Reset()
|
err := r.Reset()
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
type ipList struct {
|
type ipList struct {
|
||||||
ips map[string]struct{}
|
ips map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIpList(ip string) ipList {
|
func newIpList(ip string) *ipList {
|
||||||
ips := make(map[string]struct{})
|
ips := make(map[string]struct{})
|
||||||
ips[ip] = struct{}{}
|
ips[ip] = struct{}{}
|
||||||
|
|
||||||
return ipList{
|
return &ipList{
|
||||||
ips: ips,
|
ips: ips,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
|
|||||||
s.ips[ip] = struct{}{}
|
s.ips[ip] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
IPs map[string]struct{} `json:"ips"`
|
||||||
|
}{
|
||||||
|
IPs: s.ips,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||||
|
temp := struct {
|
||||||
|
IPs map[string]struct{} `json:"ips"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ips = temp.IPs
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type ipsetStore struct {
|
type ipsetStore struct {
|
||||||
ipsets map[string]ipList // ipsetName -> ruleset
|
ipsets map[string]*ipList
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIpsetStore() *ipsetStore {
|
func newIpsetStore() *ipsetStore {
|
||||||
return &ipsetStore{
|
return &ipsetStore{
|
||||||
ipsets: make(map[string]ipList),
|
ipsets: make(map[string]*ipList),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
|
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||||
r, ok := s.ipsets[ipsetName]
|
r, ok := s.ipsets[ipsetName]
|
||||||
return r, ok
|
return r, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
|
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||||
s.ipsets[ipsetName] = list
|
s.ipsets[ipsetName] = list
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||||
s.ipsets[ipsetName] = ipList{}
|
|
||||||
delete(s.ipsets, ipsetName)
|
delete(s.ipsets, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
|
|||||||
}
|
}
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
IPSets map[string]*ipList `json:"ipsets"`
|
||||||
|
}{
|
||||||
|
IPSets: s.ipsets,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||||
|
temp := struct {
|
||||||
|
IPSets map[string]*ipList `json:"ipsets"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ipsets = temp.IPSets
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
70
client/firewall/iptables/state_linux.go
Normal file
70
client/firewall/iptables/state_linux.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InterfaceState struct {
|
||||||
|
NameStr string `json:"name"`
|
||||||
|
WGAddress iface.WGAddress `json:"wg_address"`
|
||||||
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Name() string {
|
||||||
|
return i.NameStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Address() device.WGAddress {
|
||||||
|
return i.WGAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||||
|
return i.UserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
|
|
||||||
|
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||||
|
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||||
|
|
||||||
|
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||||
|
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "iptables_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
ipt, err := Create(s.InterfaceState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create iptables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.RouteRules != nil {
|
||||||
|
ipt.router.rules = s.RouteRules
|
||||||
|
}
|
||||||
|
if s.RouteIPsetCounter != nil {
|
||||||
|
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.ACLEntries != nil {
|
||||||
|
ipt.aclMgr.entries = s.ACLEntries
|
||||||
|
}
|
||||||
|
if s.ACLIPsetStore != nil {
|
||||||
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipt.Reset(nil); err != nil {
|
||||||
|
return fmt.Errorf("reset iptables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -10,6 +10,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -52,6 +54,8 @@ const (
|
|||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
// Netbird client for ACL and routing functionality
|
// Netbird client for ACL and routing functionality
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
|
Init(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// AllowNetbird allows netbird interface traffic
|
// AllowNetbird allows netbird interface traffic
|
||||||
AllowNetbird() error
|
AllowNetbird() error
|
||||||
|
|
||||||
@ -91,7 +95,7 @@ type Manager interface {
|
|||||||
SetLegacyManagement(legacy bool) error
|
SetLegacyManagement(legacy bool) error
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
@ -17,7 +17,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -56,13 +55,6 @@ type AclManager struct {
|
|||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
|
||||||
type iFaceMapper interface {
|
|
||||||
Name() string
|
|
||||||
Address() iface.WGAddress
|
|
||||||
IsUserspaceBind() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||||
// sConn is used for creating sets and adding/removing elements from them
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
// it's differ then rConn (which does create new conn for each flush operation)
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
|||||||
// overloads netlink with high amount of rules ( > 10000)
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
sConn, err := nftables.New(nftables.AsLasting())
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m := &AclManager{
|
return &AclManager{
|
||||||
rConn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
sConn: sConn,
|
sConn: sConn,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
|
|||||||
|
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
rules: make(map[string]*Rule),
|
rules: make(map[string]*Rule),
|
||||||
}
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
err = m.createDefaultChains()
|
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||||
if err != nil {
|
m.workTable = workTable
|
||||||
return nil, err
|
return m.createDefaultChains()
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
|
@ -14,6 +14,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -24,6 +26,13 @@ const (
|
|||||||
chainNameInput = "INPUT"
|
chainNameInput = "INPUT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
|
type iFaceMapper interface {
|
||||||
|
Name() string
|
||||||
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
@ -35,30 +44,68 @@ type Manager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rConn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
workTable, err := m.createWorkTable()
|
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.router, err = newRouter(context, workTable, wgIface)
|
var err error
|
||||||
|
m.router, err = newRouter(workTable, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init nftables firewall manager
|
||||||
|
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||||
|
workTable, err := m.createWorkTable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create work table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.router.init(workTable); err != nil {
|
||||||
|
return fmt.Errorf("router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclManager.init(workTable); err != nil {
|
||||||
|
// TODO: cleanup router
|
||||||
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
|
// We only need to record minimal interface state for potential recreation.
|
||||||
|
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||||
|
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||||
|
// cleanup using Reset() without needing to store specific rules.
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
InterfaceState: &InterfaceState{
|
||||||
|
NameStr: m.wgIface.Name(),
|
||||||
|
WGAddress: m.wgIface.Address(),
|
||||||
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist early
|
||||||
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddPeerFiltering rule to the firewall
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
@ -203,48 +250,80 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
chains, err := m.rConn.ListChains()
|
if err := m.resetNetbirdInputRules(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.router.Reset(); err != nil {
|
||||||
|
return fmt.Errorf("reset router: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
|
return fmt.Errorf("cleanup netbird tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
return fmt.Errorf("delete state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) resetNetbirdInputRules() error {
|
||||||
|
chains, err := m.rConn.ListChains()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.deleteNetbirdInputRules(chains)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
// delete Netbird allow input traffic rule if it exists
|
|
||||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||||
rules, err := m.rConn.GetRules(c.Table, c)
|
rules, err := m.rConn.GetRules(c.Table, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, r := range rules {
|
|
||||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
m.deleteMatchingRules(rules)
|
||||||
if err := m.rConn.DelRule(r); err != nil {
|
}
|
||||||
log.Errorf("delete rule: %v", err)
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||||
|
for _, r := range rules {
|
||||||
|
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||||
|
if err := m.rConn.DelRule(r); err != nil {
|
||||||
|
log.Errorf("delete rule: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.router.Reset(); err != nil {
|
func (m *Manager) cleanupNetbirdTables() error {
|
||||||
return fmt.Errorf("reset forward rules: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := m.rConn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == tableNameNetbird {
|
if t.Name == tableNameNetbird {
|
||||||
m.rConn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return m.rConn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|||||||
func TestNftablesManager(t *testing.T) {
|
func TestNftablesManager(t *testing.T) {
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// established rule remains
|
// established rule remains
|
||||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(context.Background(), mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
@ -2,7 +2,6 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -40,8 +39,6 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
ctx context.Context
|
|
||||||
stop context.CancelFunc
|
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
filterTable *nftables.Table
|
filterTable *nftables.Table
|
||||||
@ -54,12 +51,8 @@ type router struct {
|
|||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
r := &router{
|
r := &router{
|
||||||
ctx: ctx,
|
|
||||||
stop: cancel,
|
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
|
|||||||
if errors.Is(err, errFilterTableNotFound) {
|
if errors.Is(err, errFilterTableNotFound) {
|
||||||
log.Warnf("table 'filter' not found for forward rules")
|
log.Warnf("table 'filter' not found for forward rules")
|
||||||
} else {
|
} else {
|
||||||
return nil, err
|
return nil, fmt.Errorf("load filter table: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.removeAcceptForwardRules()
|
return r, nil
|
||||||
if err != nil {
|
}
|
||||||
|
|
||||||
|
func (r *router) init(workTable *nftables.Table) error {
|
||||||
|
r.workTable = workTable
|
||||||
|
|
||||||
|
if err := r.removeAcceptForwardRules(); err != nil {
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.createContainers()
|
if err := r.createContainers(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("create containers: %w", err)
|
||||||
log.Errorf("failed to create containers for route: %s", err)
|
|
||||||
}
|
}
|
||||||
return r, err
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset cleans existing nftables default forward rules from the system
|
// Reset cleans existing nftables default forward rules from the system
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
manager, err := newRouter(table, ifaceMock)
|
||||||
require.NoError(t, err, "failed to create router")
|
require.NoError(t, err, "failed to create router")
|
||||||
|
require.NoError(t, manager.init(table))
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := newRouter(context.TODO(), table, ifaceMock)
|
manager, err := newRouter(table, ifaceMock)
|
||||||
require.NoError(t, err, "failed to create router")
|
require.NoError(t, err, "failed to create router")
|
||||||
|
require.NoError(t, manager.init(table))
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
r, err := newRouter(workTable, ifaceMock)
|
||||||
require.NoError(t, err, "Failed to create router")
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func(r *router) {
|
defer func(r *router) {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||||
@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(context.Background(), workTable, ifaceMock)
|
r, err := newRouter(workTable, ifaceMock)
|
||||||
require.NoError(t, err, "Failed to create router")
|
require.NoError(t, err, "Failed to create router")
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||||
|
47
client/firewall/nftables/state_linux.go
Normal file
47
client/firewall/nftables/state_linux.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InterfaceState struct {
|
||||||
|
NameStr string `json:"name"`
|
||||||
|
WGAddress iface.WGAddress `json:"wg_address"`
|
||||||
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Name() string {
|
||||||
|
return i.NameStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) Address() device.WGAddress {
|
||||||
|
return i.WGAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||||
|
return i.UserspaceBind
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "nftables_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
nft, err := Create(s.InterfaceState)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create nftables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := nft.Reset(nil); err != nil {
|
||||||
|
return fmt.Errorf("reset nftables manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
|
|||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Reset()
|
return m.nativeFirewall.Reset(stateManager)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type action string
|
type action string
|
||||||
@ -17,7 +19,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset() error {
|
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const layerTypeAll = 0
|
||||||
@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Init(*statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return false
|
return false
|
||||||
@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return nil, errRouteNotSupported
|
return nil, errRouteNotSupported
|
||||||
}
|
}
|
||||||
|
@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.Reset()
|
err = m.Reset(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = m.Reset(); err != nil {
|
if err = m.Reset(nil); err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(fw manager.Manager) {
|
defer func(fw manager.Manager) {
|
||||||
_ = fw.Reset()
|
_ = fw.Reset(nil)
|
||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
|
fw, err := firewall.NewFirewall(ifaceMock, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(fw manager.Manager) {
|
defer func(fw manager.Manager) {
|
||||||
_ = fw.Reset()
|
_ = fw.Reset(nil)
|
||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RunWithProbes runs the client's main logic with probes attached
|
// RunWithProbes runs the client's main logic with probes attached
|
||||||
func (c *ConnectClient) RunWithProbes(
|
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
|
||||||
probes *ProbeHolder,
|
|
||||||
runningChan chan error,
|
|
||||||
) error {
|
|
||||||
return c.run(MobileDependency{}, probes, runningChan)
|
return c.run(MobileDependency{}, probes, runningChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
return c.run(mobileDependency, nil, nil)
|
return c.run(mobileDependency, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(
|
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
|
||||||
mobileDependency MobileDependency,
|
|
||||||
probes *ProbeHolder,
|
|
||||||
runningChan chan error,
|
|
||||||
) error {
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
|
@ -533,6 +533,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// persist dns state right away
|
||||||
|
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||||
|
l.Errorf("Failed to persist dns state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||||
s.addHostRootZone()
|
s.addHostRootZone()
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@ -366,7 +367,7 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
}
|
}
|
||||||
@ -1167,7 +1168,7 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Reset()
|
err := e.firewall.Reset(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to reset firewall: %s", err)
|
log.Warnf("failed to reset firewall: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package refcounter
|
package refcounter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadData loads the data from the existing counter
|
||||||
|
func (rm *Counter[Key, I, O]) LoadData(
|
||||||
|
existingCounter *Counter[Key, I, O],
|
||||||
|
) {
|
||||||
|
rm.refCountMu.Lock()
|
||||||
|
defer rm.refCountMu.Unlock()
|
||||||
|
rm.idMu.Lock()
|
||||||
|
defer rm.idMu.Unlock()
|
||||||
|
|
||||||
|
rm.refCountMap = existingCounter.refCountMap
|
||||||
|
rm.idMap = existingCounter.idMap
|
||||||
|
}
|
||||||
|
|
||||||
// Get retrieves the current reference count and associated data for a key.
|
// Get retrieves the current reference count and associated data for a key.
|
||||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||||
@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() {
|
|||||||
clear(rm.idMap)
|
clear(rm.idMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||||
|
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||||
|
IDMap map[string][]Key `json:"idMap"`
|
||||||
|
}{
|
||||||
|
RefCountMap: rm.refCountMap,
|
||||||
|
IDMap: rm.idMap,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
|
||||||
|
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||||
|
var temp struct {
|
||||||
|
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||||
|
IDMap map[string][]Key `json:"idMap"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rm.refCountMap = temp.RefCountMap
|
||||||
|
rm.idMap = temp.IDMap
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getCallerInfo(depth int, maxDepth int) (string, bool) {
|
func getCallerInfo(depth int, maxDepth int) (string, bool) {
|
||||||
if depth >= maxDepth {
|
if depth >= maxDepth {
|
||||||
return "", false
|
return "", false
|
||||||
|
@ -1,30 +1,15 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouteEntry struct {
|
|
||||||
Prefix netip.Prefix `json:"prefix"`
|
|
||||||
Nexthop Nexthop `json:"nexthop"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"`
|
Counter *ExclusionCounter `json:"counter,omitempty"`
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
|
||||||
|
|
||||||
func NewShutdownState() *ShutdownState {
|
|
||||||
return &ShutdownState{
|
|
||||||
Routes: make(map[netip.Prefix]RouteEntry),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
if s.Counter == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
sysops := NewSysOps(nil, nil)
|
sysops := NewSysOps(nil, nil)
|
||||||
var merr *multierror.Error
|
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||||
|
sysops.refCounter.LoadData(s.Counter)
|
||||||
|
|
||||||
s.mu.RLock()
|
return sysops.refCounter.Flush()
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, route := range s.Routes {
|
|
||||||
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.Routes[prefix] = RouteEntry{
|
|
||||||
Prefix: prefix,
|
|
||||||
Nexthop: nexthop,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
delete(s.Routes, prefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalJSON ensures that empty routes are marshaled as null
|
|
||||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
if len(s.Routes) == 0 {
|
|
||||||
return json.Marshal(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Marshal(s.Routes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
|
||||||
return json.Unmarshal(data, &s.Routes)
|
|
||||||
}
|
}
|
||||||
|
@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
|||||||
return nexthop, refcounter.ErrIgnore
|
return nexthop, refcounter.ErrIgnore
|
||||||
}
|
}
|
||||||
|
|
||||||
r.updateState(stateManager, prefix, nexthop)
|
r.updateState(stateManager)
|
||||||
|
|
||||||
return nexthop, err
|
return nexthop, err
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
// remove from state even if we have trouble removing it from the route table
|
// remove from state even if we have trouble removing it from the route table
|
||||||
// it could be already gone
|
// it could be already gone
|
||||||
r.removeFromState(stateManager, prefix)
|
r.updateState(stateManager)
|
||||||
|
|
||||||
return r.removeFromRouteTable(prefix, nexthop)
|
return r.removeFromRouteTable(prefix, nexthop)
|
||||||
},
|
},
|
||||||
@ -75,24 +75,16 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
|||||||
return r.setupHooks(initAddresses)
|
return r.setupHooks(initAddresses)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) {
|
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||||
state := getState(stateManager)
|
state := getState(stateManager)
|
||||||
state.UpdateRoute(prefix, nexthop)
|
|
||||||
|
state.Counter = r.refCounter
|
||||||
|
|
||||||
if err := stateManager.UpdateState(state); err != nil {
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
|
|
||||||
state := getState(stateManager)
|
|
||||||
state.RemoveRoute(prefix)
|
|
||||||
|
|
||||||
if err := stateManager.UpdateState(state); err != nil {
|
|
||||||
log.Errorf("Failed to update state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||||
if r.refCounter == nil {
|
if r.refCounter == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -107,7 +99,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
log.Errorf("failed to delete state: %v", err)
|
return fmt.Errorf("delete state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -546,7 +538,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState {
|
|||||||
if state := stateManager.GetState(shutdownState); state != nil {
|
if state := stateManager.GetState(shutdownState); state != nil {
|
||||||
shutdownState = state.(*ShutdownState)
|
shutdownState = state.(*ShutdownState)
|
||||||
} else {
|
} else {
|
||||||
shutdownState = NewShutdownState()
|
shutdownState = &ShutdownState{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return shutdownState
|
return shutdownState
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||||
@ -27,7 +27,7 @@ func GetDefaultStatePath() string {
|
|||||||
|
|
||||||
dir := filepath.Dir(path)
|
dir := filepath.Dir(path)
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
@ -21,11 +20,7 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@ -848,31 +843,3 @@ func sendTerminalNotification() error {
|
|||||||
|
|
||||||
return wallCmd.Wait()
|
return wallCmd.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required.
|
|
||||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
|
||||||
func restoreResidualState(ctx context.Context) error {
|
|
||||||
path := statemanager.GetDefaultStatePath()
|
|
||||||
if path == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mgr := statemanager.New(path)
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
// register the states we are interested in restoring
|
|
||||||
// this will also allow each subsystem to record its own state
|
|
||||||
mgr.RegisterState(&dns.ShutdownState{})
|
|
||||||
mgr.RegisterState(&systemops.ShutdownState{})
|
|
||||||
|
|
||||||
if err := mgr.PerformCleanup(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := mgr.PersistState(ctx); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
37
client/server/state.go
Normal file
37
client/server/state.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required.
|
||||||
|
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||||
|
func restoreResidualState(ctx context.Context) error {
|
||||||
|
path := statemanager.GetDefaultStatePath()
|
||||||
|
if path == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := statemanager.New(path)
|
||||||
|
|
||||||
|
// register the states we are interested in restoring
|
||||||
|
registerStates(mgr)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if err := mgr.PerformCleanup(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist state regardless of cleanup outcome. It could've succeeded partially
|
||||||
|
if err := mgr.PersistState(ctx); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
14
client/server/state_generic.go
Normal file
14
client/server/state_generic.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerStates(mgr *statemanager.Manager) {
|
||||||
|
mgr.RegisterState(&dns.ShutdownState{})
|
||||||
|
mgr.RegisterState(&systemops.ShutdownState{})
|
||||||
|
}
|
18
client/server/state_linux.go
Normal file
18
client/server/state_linux.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerStates(mgr *statemanager.Manager) {
|
||||||
|
mgr.RegisterState(&dns.ShutdownState{})
|
||||||
|
mgr.RegisterState(&systemops.ShutdownState{})
|
||||||
|
mgr.RegisterState(&nftables.ShutdownState{})
|
||||||
|
mgr.RegisterState(&iptables.ShutdownState{})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user