mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 10:50:45 +01:00
391 lines
8.3 KiB
Go
391 lines
8.3 KiB
Go
package statemanager
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"os"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/exp/maps"
|
|
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
"github.com/netbirdio/netbird/util"
|
|
)
|
|
|
|
// State interface defines the methods that all state types must implement
|
|
type State interface {
|
|
Name() string
|
|
}
|
|
|
|
// CleanableState interface extends State with cleanup capability
|
|
type CleanableState interface {
|
|
State
|
|
Cleanup() error
|
|
}
|
|
|
|
// RawState wraps raw JSON data for unregistered states
|
|
type RawState struct {
|
|
data json.RawMessage
|
|
}
|
|
|
|
func (r *RawState) Name() string {
|
|
return "" // This is a placeholder implementation
|
|
}
|
|
|
|
// MarshalJSON implements json.Marshaler to preserve the original JSON
|
|
func (r *RawState) MarshalJSON() ([]byte, error) {
|
|
return r.data, nil
|
|
}
|
|
|
|
// Manager handles the persistence and management of various states
|
|
type Manager struct {
|
|
mu sync.Mutex
|
|
cancel context.CancelFunc
|
|
done chan struct{}
|
|
|
|
filePath string
|
|
// holds the states that are registered with the manager and that are to be persisted
|
|
states map[string]State
|
|
// holds the state names that have been updated and need to be persisted with the next save
|
|
dirty map[string]struct{}
|
|
// holds the type information for each registered state
|
|
stateTypes map[string]reflect.Type
|
|
}
|
|
|
|
// New creates a new Manager instance
|
|
func New(filePath string) *Manager {
|
|
return &Manager{
|
|
filePath: filePath,
|
|
states: make(map[string]State),
|
|
dirty: make(map[string]struct{}),
|
|
stateTypes: make(map[string]reflect.Type),
|
|
}
|
|
}
|
|
|
|
// Start starts the state manager periodic save routine
|
|
func (m *Manager) Start() {
|
|
if m == nil {
|
|
return
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var ctx context.Context
|
|
ctx, m.cancel = context.WithCancel(context.Background())
|
|
m.done = make(chan struct{})
|
|
|
|
go m.periodicStateSave(ctx)
|
|
}
|
|
|
|
func (m *Manager) Stop(ctx context.Context) error {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.cancel == nil {
|
|
return nil
|
|
}
|
|
m.cancel()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-m.done:
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RegisterState registers a state with the manager but doesn't attempt to persist it.
|
|
// Pass an uninitialized state to register it.
|
|
func (m *Manager) RegisterState(state State) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
name := state.Name()
|
|
m.states[name] = nil
|
|
m.stateTypes[name] = reflect.TypeOf(state).Elem()
|
|
}
|
|
|
|
// GetState returns the state for the given type
|
|
func (m *Manager) GetState(state State) State {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
return m.states[state.Name()]
|
|
}
|
|
|
|
// UpdateState updates the state in the manager and marks it as dirty for the next save.
|
|
// The state will be replaced with the new one.
|
|
func (m *Manager) UpdateState(state State) error {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
return m.setState(state.Name(), state)
|
|
}
|
|
|
|
// DeleteState removes the state from the manager and marks it as dirty for the next save.
|
|
// Pass an uninitialized state to delete it.
|
|
func (m *Manager) DeleteState(state State) error {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
return m.setState(state.Name(), nil)
|
|
}
|
|
|
|
func (m *Manager) setState(name string, state State) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if _, exists := m.states[name]; !exists {
|
|
return fmt.Errorf("state %s not registered", name)
|
|
}
|
|
|
|
m.states[name] = state
|
|
m.dirty[name] = struct{}{}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) periodicStateSave(ctx context.Context) {
|
|
ticker := time.NewTicker(10 * time.Second)
|
|
defer ticker.Stop()
|
|
defer close(m.done)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if err := m.PersistState(ctx); err != nil {
|
|
log.Errorf("failed to persist state: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// PersistState persists the states that have been updated since the last save.
|
|
func (m *Manager) PersistState(ctx context.Context) error {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if len(m.dirty) == 0 {
|
|
return nil
|
|
}
|
|
|
|
bs, err := marshalWithPanicRecovery(m.states)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal states: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
done := make(chan error, 1)
|
|
start := time.Now()
|
|
go func() {
|
|
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err := <-done:
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
|
|
|
clear(m.dirty)
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
|
|
func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) {
|
|
data, err := os.ReadFile(m.filePath)
|
|
if err != nil {
|
|
if errors.Is(err, fs.ErrNotExist) {
|
|
log.Debug("state file does not exist")
|
|
return nil, nil // nolint:nilnil
|
|
}
|
|
return nil, fmt.Errorf("read state file: %w", err)
|
|
}
|
|
|
|
var rawStates map[string]json.RawMessage
|
|
if err := json.Unmarshal(data, &rawStates); err != nil {
|
|
log.Warn("State file appears to be corrupted, attempting to delete it")
|
|
if err := os.Remove(m.filePath); err != nil {
|
|
log.Errorf("Failed to delete corrupted state file: %v", err)
|
|
} else {
|
|
log.Info("State file deleted")
|
|
}
|
|
return nil, fmt.Errorf("unmarshal states: %w", err)
|
|
}
|
|
|
|
return rawStates, nil
|
|
}
|
|
|
|
// loadSingleRawState unmarshals a raw state into a concrete state object
|
|
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
|
stateType, ok := m.stateTypes[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("state %s not registered", name)
|
|
}
|
|
|
|
if string(rawState) == "null" {
|
|
return nil, nil //nolint:nilnil
|
|
}
|
|
|
|
statePtr := reflect.New(stateType).Interface().(State)
|
|
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
|
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
|
|
}
|
|
|
|
return statePtr, nil
|
|
}
|
|
|
|
// LoadState loads a specific state from the state file
|
|
func (m *Manager) LoadState(state State) error {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
rawStates, err := m.loadStateFile()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rawStates == nil {
|
|
return nil
|
|
}
|
|
|
|
name := state.Name()
|
|
rawState, exists := rawStates[name]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
loadedState, err := m.loadSingleRawState(name, rawState)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m.states[name] = loadedState
|
|
if loadedState != nil {
|
|
log.Debugf("loaded state: %s", name)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
|
|
// Unregistered states are preserved in their original state.
|
|
func (m *Manager) PerformCleanup() error {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Load raw states from file
|
|
rawStates, err := m.loadStateFile()
|
|
if err != nil {
|
|
log.Warnf("Failed to load state during cleanup: %v", err)
|
|
return err
|
|
}
|
|
if rawStates == nil {
|
|
return nil
|
|
}
|
|
|
|
var merr *multierror.Error
|
|
|
|
// Process each state in the file
|
|
for name, rawState := range rawStates {
|
|
// For unregistered states, preserve the raw JSON
|
|
if _, registered := m.stateTypes[name]; !registered {
|
|
m.states[name] = &RawState{data: rawState}
|
|
continue
|
|
}
|
|
|
|
// Load the registered state
|
|
loadedState, err := m.loadSingleRawState(name, rawState)
|
|
if err != nil {
|
|
merr = multierror.Append(merr, err)
|
|
continue
|
|
}
|
|
|
|
if loadedState == nil {
|
|
continue
|
|
}
|
|
|
|
// Check if state supports cleanup
|
|
cleanableState, isCleanable := loadedState.(CleanableState)
|
|
if !isCleanable {
|
|
// If it doesn't support cleanup, keep it as-is
|
|
m.states[name] = loadedState
|
|
continue
|
|
}
|
|
|
|
// Perform cleanup for cleanable states
|
|
log.Infof("client was not shut down properly, cleaning up %s", name)
|
|
if err := cleanableState.Cleanup(); err != nil {
|
|
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
|
// On cleanup error, preserve the state
|
|
m.states[name] = loadedState
|
|
} else {
|
|
// Successfully cleaned up - mark for deletion
|
|
m.states[name] = nil
|
|
m.dirty[name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
return nberrors.FormatErrorOrNil(merr)
|
|
}
|
|
|
|
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
|
var bs []byte
|
|
var err error
|
|
|
|
func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = fmt.Errorf("panic during marshal: %v", r)
|
|
}
|
|
}()
|
|
bs, err = json.Marshal(v)
|
|
}()
|
|
|
|
return bs, err
|
|
}
|