mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
[client] Persist route selection (#2810)
This commit is contained in:
parent
ecb44ff306
commit
5142dc52c1
@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.ips = temp.IPs
|
s.ips = temp.IPs
|
||||||
|
|
||||||
|
if temp.IPs == nil {
|
||||||
|
temp.IPs = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.ipsets = temp.IPSets
|
s.ipsets = temp.IPSets
|
||||||
|
|
||||||
|
if temp.IPSets == nil {
|
||||||
|
temp.IPSets = make(map[string]*ipList)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -349,8 +349,17 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
e.routeManager = routemanager.NewManager(
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
e.ctx,
|
||||||
|
e.config.WgPrivateKey.PublicKey().String(),
|
||||||
|
e.config.DNSRouteInterval,
|
||||||
|
e.wgInterface,
|
||||||
|
e.statusRecorder,
|
||||||
|
e.relayManager,
|
||||||
|
initialRoutes,
|
||||||
|
e.stateManager,
|
||||||
|
)
|
||||||
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
nil)
|
nil)
|
||||||
|
|
||||||
wgIface := &iface.MockWGIface{
|
wgIface := &iface.MockWGIface{
|
||||||
|
NameFunc: func() string { return "utun102" },
|
||||||
RemovePeerFunc: func(peerKey string) error {
|
RemovePeerFunc: func(peerKey string) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
||||||
|
_, _, err = engine.routeManager.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ import (
|
|||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
@ -59,6 +59,7 @@ type DefaultManager struct {
|
|||||||
routeRefCounter *refcounter.RouteRefCounter
|
routeRefCounter *refcounter.RouteRefCounter
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(
|
func NewManager(
|
||||||
@ -69,6 +70,7 @@ func NewManager(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
relayMgr *relayClient.Manager,
|
relayMgr *relayClient.Manager,
|
||||||
initialRoutes []*route.Route,
|
initialRoutes []*route.Route,
|
||||||
|
stateManager *statemanager.Manager,
|
||||||
) *DefaultManager {
|
) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
@ -80,12 +82,12 @@ func NewManager(
|
|||||||
dnsRouteInterval: dnsRouteInterval,
|
dnsRouteInterval: dnsRouteInterval,
|
||||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||||
relayMgr: relayMgr,
|
relayMgr: relayMgr,
|
||||||
routeSelector: routeselector.NewRouteSelector(),
|
|
||||||
sysOps: sysOps,
|
sysOps: sysOps,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
pubKey: pubKey,
|
pubKey: pubKey,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
|
stateManager: stateManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.routeRefCounter = refcounter.New(
|
dm.routeRefCounter = refcounter.New(
|
||||||
@ -121,7 +123,7 @@ func NewManager(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
if nbnet.CustomRoutingDisabled() {
|
if nbnet.CustomRoutingDisabled() {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
|
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.routeSelector = m.initSelector()
|
||||||
|
|
||||||
log.Info("Routing setup complete")
|
log.Info("Routing setup complete")
|
||||||
return beforePeerHook, afterPeerHook, nil
|
return beforePeerHook, afterPeerHook, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||||
|
var state *SelectorState
|
||||||
|
m.stateManager.RegisterState(state)
|
||||||
|
|
||||||
|
// restore selector state if it exists
|
||||||
|
if err := m.stateManager.LoadState(state); err != nil {
|
||||||
|
log.Warnf("failed to load state: %v", err)
|
||||||
|
return routeselector.NewRouteSelector()
|
||||||
|
}
|
||||||
|
|
||||||
|
if state := m.stateManager.GetState(state); state != nil {
|
||||||
|
if selector, ok := state.(*SelectorState); ok {
|
||||||
|
return (*routeselector.RouteSelector)(selector)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("failed to convert state with type %T to SelectorState", state)
|
||||||
|
}
|
||||||
|
|
||||||
|
return routeselector.NewRouteSelector()
|
||||||
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||||
|
@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
|
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
||||||
|
|
||||||
_, _, err = routeManager.Init(nil)
|
_, _, err = routeManager.Init()
|
||||||
|
|
||||||
require.NoError(t, err, "should init route manager")
|
require.NoError(t, err, "should init route manager")
|
||||||
defer routeManager.Stop(nil)
|
defer routeManager.Stop(nil)
|
||||||
|
@ -21,7 +21,7 @@ type MockManager struct {
|
|||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
|
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoadData loads the data from the existing counter
|
// LoadData loads the data from the existing counter
|
||||||
|
// The passed counter should not be used any longer after calling this function.
|
||||||
func (rm *Counter[Key, I, O]) LoadData(
|
func (rm *Counter[Key, I, O]) LoadData(
|
||||||
existingCounter *Counter[Key, I, O],
|
existingCounter *Counter[Key, I, O],
|
||||||
) {
|
) {
|
||||||
rm.mu.Lock()
|
rm.mu.Lock()
|
||||||
defer rm.mu.Unlock()
|
defer rm.mu.Unlock()
|
||||||
|
existingCounter.mu.Lock()
|
||||||
|
defer existingCounter.mu.Unlock()
|
||||||
|
|
||||||
rm.refCountMap = existingCounter.refCountMap
|
rm.refCountMap = existingCounter.refCountMap
|
||||||
rm.idMap = existingCounter.idMap
|
rm.idMap = existingCounter.idMap
|
||||||
@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
|||||||
|
|
||||||
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
|
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
|
||||||
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||||
|
rm.mu.Lock()
|
||||||
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
var temp struct {
|
var temp struct {
|
||||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||||
IDMap map[string][]Key `json:"idMap"`
|
IDMap map[string][]Key `json:"idMap"`
|
||||||
@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
|||||||
rm.refCountMap = temp.RefCountMap
|
rm.refCountMap = temp.RefCountMap
|
||||||
rm.idMap = temp.IDMap
|
rm.idMap = temp.IDMap
|
||||||
|
|
||||||
|
if temp.RefCountMap == nil {
|
||||||
|
temp.RefCountMap = map[Key]Ref[O]{}
|
||||||
|
}
|
||||||
|
if temp.IDMap == nil {
|
||||||
|
temp.IDMap = map[string][]Key{}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
19
client/internal/routemanager/state.go
Normal file
19
client/internal/routemanager/state.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SelectorState routeselector.RouteSelector
|
||||||
|
|
||||||
|
func (s *SelectorState) Name() string {
|
||||||
|
return "routeselector_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SelectorState) MarshalJSON() ([]byte, error) {
|
||||||
|
return (*routeselector.RouteSelector)(s).MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SelectorState) UnmarshalJSON(data []byte) error {
|
||||||
|
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
|
||||||
|
}
|
@ -1,8 +1,10 @@
|
|||||||
package routeselector
|
package routeselector
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@ -12,6 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RouteSelector struct {
|
type RouteSelector struct {
|
||||||
|
mu sync.RWMutex
|
||||||
selectedRoutes map[route.NetID]struct{}
|
selectedRoutes map[route.NetID]struct{}
|
||||||
selectAll bool
|
selectAll bool
|
||||||
}
|
}
|
||||||
@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {
|
|||||||
|
|
||||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||||
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
|
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
|
||||||
|
rs.mu.Lock()
|
||||||
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
if !appendRoute {
|
if !appendRoute {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
|||||||
|
|
||||||
// SelectAllRoutes sets the selector to select all routes.
|
// SelectAllRoutes sets the selector to select all routes.
|
||||||
func (rs *RouteSelector) SelectAllRoutes() {
|
func (rs *RouteSelector) SelectAllRoutes() {
|
||||||
|
rs.mu.Lock()
|
||||||
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
rs.selectAll = true
|
rs.selectAll = true
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
|||||||
// DeselectRoutes removes specific routes from the selection.
|
// DeselectRoutes removes specific routes from the selection.
|
||||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||||
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
|
rs.mu.Lock()
|
||||||
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
|
|||||||
|
|
||||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||||
|
rs.mu.Lock()
|
||||||
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSelected checks if a specific route is selected.
|
// IsSelected checks if a specific route is selected.
|
||||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||||
|
rs.mu.RLock()
|
||||||
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
|||||||
|
|
||||||
// FilterSelected removes unselected routes from the provided map.
|
// FilterSelected removes unselected routes from the provided map.
|
||||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||||
|
rs.mu.RLock()
|
||||||
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return maps.Clone(routes)
|
return maps.Clone(routes)
|
||||||
}
|
}
|
||||||
@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
|||||||
}
|
}
|
||||||
return filtered
|
return filtered
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface
|
||||||
|
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
||||||
|
rs.mu.RLock()
|
||||||
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||||
|
SelectAll bool `json:"select_all"`
|
||||||
|
}{
|
||||||
|
SelectAll: rs.selectAll,
|
||||||
|
SelectedRoutes: rs.selectedRoutes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface
|
||||||
|
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
|
||||||
|
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||||
|
rs.mu.Lock()
|
||||||
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
|
// Check for null or empty JSON
|
||||||
|
if len(data) == 0 || string(data) == "null" {
|
||||||
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
|
rs.selectAll = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var temp struct {
|
||||||
|
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||||
|
SelectAll bool `json:"select_all"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rs.selectedRoutes = temp.SelectedRoutes
|
||||||
|
rs.selectAll = temp.SelectAll
|
||||||
|
|
||||||
|
if rs.selectedRoutes == nil {
|
||||||
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -22,9 +22,28 @@ import (
|
|||||||
// State interface defines the methods that all state types must implement
|
// State interface defines the methods that all state types must implement
|
||||||
type State interface {
|
type State interface {
|
||||||
Name() string
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanableState interface extends State with cleanup capability
|
||||||
|
type CleanableState interface {
|
||||||
|
State
|
||||||
Cleanup() error
|
Cleanup() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RawState wraps raw JSON data for unregistered states
|
||||||
|
type RawState struct {
|
||||||
|
data json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RawState) Name() string {
|
||||||
|
return "" // This is a placeholder implementation
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler to preserve the original JSON
|
||||||
|
func (r *RawState) MarshalJSON() ([]byte, error) {
|
||||||
|
return r.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Manager handles the persistence and management of various states
|
// Manager handles the persistence and management of various states
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadState loads the existing state from the state file
|
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
|
||||||
func (m *Manager) loadState() error {
|
func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
|
||||||
data, err := os.ReadFile(m.filePath)
|
data, err := os.ReadFile(m.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
log.Debug("state file does not exist")
|
log.Debug("state file does not exist")
|
||||||
return nil
|
return nil, nil // nolint:nilnil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("read state file: %w", err)
|
return nil, fmt.Errorf("read state file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var rawStates map[string]json.RawMessage
|
var rawStates map[string]json.RawMessage
|
||||||
@ -228,37 +247,69 @@ func (m *Manager) loadState() error {
|
|||||||
} else {
|
} else {
|
||||||
log.Info("State file deleted")
|
log.Info("State file deleted")
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unmarshal states: %w", err)
|
return nil, fmt.Errorf("unmarshal states: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
return rawStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
for name, rawState := range rawStates {
|
// loadSingleRawState unmarshals a raw state into a concrete state object
|
||||||
|
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
||||||
stateType, ok := m.stateTypes[name]
|
stateType, ok := m.stateTypes[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
|
return nil, fmt.Errorf("state %s not registered", name)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(rawState) == "null" {
|
if string(rawState) == "null" {
|
||||||
continue
|
return nil, nil //nolint:nilnil
|
||||||
}
|
}
|
||||||
|
|
||||||
statePtr := reflect.New(stateType).Interface().(State)
|
statePtr := reflect.New(stateType).Interface().(State)
|
||||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
|
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.states[name] = statePtr
|
return statePtr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadState loads a specific state from the state file
|
||||||
|
func (m *Manager) LoadState(state State) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
rawStates, err := m.loadStateFile()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := state.Name()
|
||||||
|
rawState, exists := rawStates[name]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.states[name] = loadedState
|
||||||
|
if loadedState != nil {
|
||||||
log.Debugf("loaded state: %s", name)
|
log.Debugf("loaded state: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
|
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
|
||||||
// If the cleanup is successful, the state is marked for deletion.
|
// Unregistered states are preserved in their original state.
|
||||||
func (m *Manager) PerformCleanup() error {
|
func (m *Manager) PerformCleanup() error {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
if err := m.loadState(); err != nil {
|
// Load raw states from file
|
||||||
|
rawStates, err := m.loadStateFile()
|
||||||
|
if err != nil {
|
||||||
log.Warnf("Failed to load state during cleanup: %v", err)
|
log.Warnf("Failed to load state during cleanup: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if rawStates == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
for name, state := range m.states {
|
|
||||||
if state == nil {
|
// Process each state in the file
|
||||||
// If no state was found in the state file, we don't mark the state dirty nor return an error
|
for name, rawState := range rawStates {
|
||||||
|
// For unregistered states, preserve the raw JSON
|
||||||
|
if _, registered := m.stateTypes[name]; !registered {
|
||||||
|
m.states[name] = &RawState{data: rawState}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load the registered state
|
||||||
|
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if loadedState == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if state supports cleanup
|
||||||
|
cleanableState, isCleanable := loadedState.(CleanableState)
|
||||||
|
if !isCleanable {
|
||||||
|
// If it doesn't support cleanup, keep it as-is
|
||||||
|
m.states[name] = loadedState
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform cleanup for cleanable states
|
||||||
log.Infof("client was not shut down properly, cleaning up %s", name)
|
log.Infof("client was not shut down properly, cleaning up %s", name)
|
||||||
if err := state.Cleanup(); err != nil {
|
if err := cleanableState.Cleanup(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
||||||
|
// On cleanup error, preserve the state
|
||||||
|
m.states[name] = loadedState
|
||||||
} else {
|
} else {
|
||||||
// mark for deletion on cleanup success
|
// Successfully cleaned up - mark for deletion
|
||||||
m.states[name] = nil
|
m.states[name] = nil
|
||||||
m.dirty[name] = struct{}{}
|
m.dirty[name] = struct{}{}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user