mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 02:50:43 +02:00
[client] Cleanup dns and route states on startup (#2757)
This commit is contained in:
81
client/internal/routemanager/systemops/state.go
Normal file
81
client/internal/routemanager/systemops/state.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
type RouteEntry struct {
|
||||
Prefix netip.Prefix `json:"prefix"`
|
||||
Nexthop Nexthop `json:"nexthop"`
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewShutdownState() *ShutdownState {
|
||||
return &ShutdownState{
|
||||
Routes: make(map[netip.Prefix]RouteEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "route_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
sysops := NewSysOps(nil, nil)
|
||||
var merr *multierror.Error
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, route := range s.Routes {
|
||||
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.Routes[prefix] = RouteEntry{
|
||||
Prefix: prefix,
|
||||
Nexthop: nexthop,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.Routes, prefix)
|
||||
}
|
||||
|
||||
// MarshalJSON ensures that empty routes are marshaled as null
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if len(s.Routes) == 0 {
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
|
||||
return json.Marshal(s.Routes)
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, &s.Routes)
|
||||
}
|
@@ -9,14 +9,15 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -28,6 +29,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
|
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
@@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
||||
// These errors are not critical, but also we should not track and try to remove the routes either.
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
|
||||
r.updateState(stateManager, prefix, nexthop)
|
||||
|
||||
return nexthop, err
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
// remove from state even if we have trouble removing it from the route table
|
||||
// it could be already gone
|
||||
r.removeFromState(stateManager, prefix)
|
||||
|
||||
return r.removeFromRouteTable(prefix, nexthop)
|
||||
},
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
@@ -63,7 +75,25 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
||||
return r.setupHooks(initAddresses)
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupRefCounter() error {
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) {
|
||||
state := getState(stateManager)
|
||||
state.UpdateRoute(prefix, nexthop)
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
|
||||
state := getState(stateManager)
|
||||
state.RemoveRoute(prefix)
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("Failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
if r.refCounter == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -76,6 +106,10 @@ func (r *SysOps) cleanupRefCounter() error {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to delete state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -506,3 +540,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
||||
|
||||
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||
var shutdownState *ShutdownState
|
||||
if state := stateManager.GetState(shutdownState); state != nil {
|
||||
shutdownState = state.(*ShutdownState)
|
||||
} else {
|
||||
shutdownState = NewShutdownState()
|
||||
}
|
||||
|
||||
return shutdownState
|
||||
}
|
||||
|
@@ -77,10 +77,10 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
_, _, err = r.SetupRouting(nil)
|
||||
_, _, err = r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
@@ -403,10 +403,10 @@ func setupTestEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
_, _, err := r.SetupRouting(nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
|
@@ -9,17 +9,18 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
@@ -46,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) notify() {
|
||||
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
|
||||
for prefix := range r.prefixes {
|
||||
prefixes = append(prefixes, prefix)
|
||||
}
|
||||
r.notifier.OnNewPrefixes(prefixes)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
@@ -54,11 +67,3 @@ func EnableIPForwarding() error {
|
||||
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
func (r *SysOps) notify() {
|
||||
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
|
||||
for prefix := range r.prefixes {
|
||||
prefixes = append(prefixes, prefix)
|
||||
}
|
||||
r.notifier.OnNewPrefixes(prefixes)
|
||||
}
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -85,10 +86,10 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||
if isLegacy() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses)
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
@@ -104,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(); cleanErr != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||
setIsLegacy(true)
|
||||
return r.setupRefCounter(initAddresses)
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
@@ -128,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
if isLegacy() {
|
||||
return r.cleanupRefCounter()
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
|
@@ -13,15 +13,16 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
|
@@ -22,6 +22,7 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -130,12 +131,12 @@ const (
|
||||
RouteDeleted
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
|
Reference in New Issue
Block a user