[client] Add state handling cmdline options (#2821)

This commit is contained in:
Viktor Liu
2024-12-03 16:07:18 +01:00
committed by GitHub
parent 8866394eb6
commit e5d42bc963
7 changed files with 1237 additions and 173 deletions

View File

@ -19,6 +19,11 @@ import (
"github.com/netbirdio/netbird/util"
)
const (
errStateNotRegistered = "state %s not registered"
errLoadStateFile = "load state file: %w"
)
// State interface defines the methods that all state types must implement
type State interface {
Name() string
@ -159,7 +164,7 @@ func (m *Manager) setState(name string, state State) error {
defer m.mu.Unlock()
if _, exists := m.states[name]; !exists {
return fmt.Errorf("state %s not registered", name)
return fmt.Errorf(errStateNotRegistered, name)
}
m.states[name] = state
@ -168,6 +173,63 @@ func (m *Manager) setState(name string, state State) error {
return nil
}
// DeleteStateByName handles deletion of states without cleanup.
// It doesn't require the state to be registered.
func (m *Manager) DeleteStateByName(stateName string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
}
if _, exists := rawStates[stateName]; !exists {
return fmt.Errorf("state %s not found", stateName)
}
// Mark state as deleted by setting it to nil and marking it dirty
m.states[stateName] = nil
m.dirty[stateName] = struct{}{}
return nil
}
// DeleteAllStates removes all states.
func (m *Manager) DeleteAllStates() (int, error) {
if m == nil {
return 0, nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return 0, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return 0, nil
}
count := len(rawStates)
// Mark all states as deleted and dirty
for name := range rawStates {
m.states[name] = nil
m.dirty[name] = struct{}{}
}
return count, nil
}
func (m *Manager) periodicStateSave(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
@ -221,7 +283,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
}
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty)
@ -229,7 +291,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) {
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
@ -241,11 +303,13 @@ func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
log.Warn("State file appears to be corrupted, attempting to delete it")
if err := os.Remove(m.filePath); err != nil {
log.Errorf("Failed to delete corrupted state file: %v", err)
} else {
log.Info("State file deleted")
if deleteCorrupt {
log.Warn("State file appears to be corrupted, attempting to delete it", err)
if err := os.Remove(m.filePath); err != nil {
log.Errorf("Failed to delete corrupted state file: %v", err)
} else {
log.Info("State file deleted")
}
}
return nil, fmt.Errorf("unmarshal states: %w", err)
}
@ -257,7 +321,7 @@ func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
stateType, ok := m.stateTypes[name]
if !ok {
return nil, fmt.Errorf("state %s not registered", name)
return nil, fmt.Errorf(errStateNotRegistered, name)
}
if string(rawState) == "null" {
@ -281,7 +345,7 @@ func (m *Manager) LoadState(state State) error {
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile()
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
@ -308,6 +372,84 @@ func (m *Manager) LoadState(state State) error {
return nil
}
// cleanupSingleState handles the cleanup of a specific state and returns any error.
// The caller must hold the mutex.
func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error {
// For unregistered states, preserve the raw JSON
if _, registered := m.stateTypes[name]; !registered {
m.states[name] = &RawState{data: rawState}
return nil
}
// Load the state
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
if loadedState == nil {
return nil
}
// Check if state supports cleanup
cleanableState, isCleanable := loadedState.(CleanableState)
if !isCleanable {
// If it doesn't support cleanup, keep it as-is
m.states[name] = loadedState
return nil
}
// Perform cleanup
log.Infof("cleaning up state %s", name)
if err := cleanableState.Cleanup(); err != nil {
// On cleanup error, preserve the state
m.states[name] = loadedState
return fmt.Errorf("cleanup state: %w", err)
}
// Successfully cleaned up - mark for deletion
m.states[name] = nil
m.dirty[name] = struct{}{}
return nil
}
// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState.
// Returns an error if the state doesn't exist, isn't registered, or cleanup fails.
func (m *Manager) CleanupStateByName(name string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
// Check if state is registered
if _, registered := m.stateTypes[name]; !registered {
return fmt.Errorf(errStateNotRegistered, name)
}
// Load raw states from file
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
// Check if state exists in file
rawState, exists := rawStates[name]
if !exists {
return nil
}
if err := m.cleanupSingleState(name, rawState); err != nil {
return fmt.Errorf("%s: %w", name, err)
}
return nil
}
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
// Unregistered states are preserved in their original state.
func (m *Manager) PerformCleanup() error {
@ -319,10 +461,9 @@ func (m *Manager) PerformCleanup() error {
defer m.mu.Unlock()
// Load raw states from file
rawStates, err := m.loadStateFile()
rawStates, err := m.loadStateFile(true)
if err != nil {
log.Warnf("Failed to load state during cleanup: %v", err)
return err
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
@ -332,47 +473,38 @@ func (m *Manager) PerformCleanup() error {
// Process each state in the file
for name, rawState := range rawStates {
// For unregistered states, preserve the raw JSON
if _, registered := m.stateTypes[name]; !registered {
m.states[name] = &RawState{data: rawState}
continue
}
// Load the registered state
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if loadedState == nil {
continue
}
// Check if state supports cleanup
cleanableState, isCleanable := loadedState.(CleanableState)
if !isCleanable {
// If it doesn't support cleanup, keep it as-is
m.states[name] = loadedState
continue
}
// Perform cleanup for cleanable states
log.Infof("client was not shut down properly, cleaning up %s", name)
if err := cleanableState.Cleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
// On cleanup error, preserve the state
m.states[name] = loadedState
} else {
// Successfully cleaned up - mark for deletion
m.states[name] = nil
m.dirty[name] = struct{}{}
if err := m.cleanupSingleState(name, rawState); err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// GetSavedStateNames returns all state names that are currently saved in the state file.
func (m *Manager) GetSavedStateNames() ([]string, error) {
if m == nil {
return nil, nil
}
rawStates, err := m.loadStateFile(false)
if err != nil {
return nil, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil, nil
}
var states []string
for name, state := range rawStates {
if len(state) != 0 && string(state) != "null" {
states = append(states, name)
}
}
return states, nil
}
func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error