[client] Fix state manager race conditions (#2890)

This commit is contained in:
Viktor Liu 2024-11-15 20:05:26 +01:00 committed by GitHub
parent a1c5287b7c
commit 121dfda915
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 118 additions and 100 deletions

View File

@ -7,7 +7,6 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err) log.Error(err)
} }
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
// don't block
go func() { go func() {
if err := s.stateManager.PersistState(ctx); err != nil { // persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err) log.Errorf("Failed to persist dns state: %v", err)
} }
}() }()
@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
// persist dns state right away go func() {
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) if err := s.stateManager.PersistState(s.ctx); err != nil {
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err) l.Errorf("Failed to persist dns state: %v", err)
} }
}()
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()

View File

@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil { if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err) return fmt.Errorf("failed to stop state manager: %w", err)
} }
if err := e.stateManager.PersistState(ctx); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }

View File

@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct { type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys // refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O] refCountMap map[Key]Ref[O]
refCountMu sync.Mutex mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal // idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O] add AddFunc[Key, I, O]
remove RemoveFunc[Key, O] remove RemoveFunc[Key, O]
} }
@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData( func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O], existingCounter *Counter[Key, I, O],
) { ) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.refCountMap = existingCounter.refCountMap rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap rm.idMap = existingCounter.idMap
@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key. // Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false. // If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
return ref, ok return ref, ok
@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key. // Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key] ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID. // IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called. // If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
ref, err := rm.Increment(key, in) ref, err := rm.increment(key, in)
if err != nil { if err != nil {
return ref, fmt.Errorf("with ID: %w", err) return ref, fmt.Errorf("with ID: %w", err)
} }
@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key. // Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
return rm.decrement(key)
}
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key] ref, ok := rm.refCountMap[key]
if !ok { if !ok {
logCallerF("No reference found for key %v", key) logCallerF("No reference found for key %v", key)
@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID. // DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called. // If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock() rm.mu.Lock()
defer rm.idMu.Unlock() defer rm.mu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for _, key := range rm.idMap[id] { for _, key := range rm.idMap[id] {
if _, err := rm.Decrement(key); err != nil { if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
} }
@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key. // Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error { func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for key := range rm.refCountMap { for key := range rm.refCountMap {
@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc. // Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() { func (rm *Counter[Key, I, O]) Clear() {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
clear(rm.refCountMap) clear(rm.refCountMap)
clear(rm.idMap) clear(rm.idMap)
@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter. // MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.refCountMu.Lock() rm.mu.Lock()
defer rm.refCountMu.Unlock() defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
return json.Marshal(struct { return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"` RefCountMap map[Key]Ref[O] `json:"refCountMap"`

View File

@ -2,31 +2,28 @@ package systemops
import ( import (
"net/netip" "net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
) )
type ShutdownState struct { type ShutdownState ExclusionCounter
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
return "route_state" return "route_state"
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil) sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData(s.Counter) sysops.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysops.refCounter.Flush()
} }
// MarshalJSON implements json.Marshaler for ShutdownState by converting the
// receiver to *ExclusionCounter and calling that type's MarshalJSON.
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}
// UnmarshalJSON implements json.Unmarshaler for ShutdownState by converting the
// receiver to *ExclusionCounter and calling that type's UnmarshalJSON with data.
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

View File

@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, err return nexthop, err
}, },
func(prefix netip.Prefix, nexthop Nexthop) error { func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table // update state even if we have trouble removing it from the route table
// it could be already gone // it could be already gone
r.updateState(stateManager) r.updateState(stateManager)
@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses) return r.setupHooks(initAddresses)
} }
// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) { func (r *SysOps) updateState(stateManager *statemanager.Manager) {
state := getState(stateManager) if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
} }
@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes // Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix return isVpn, longestPrefix
} }
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}
return shutdownState
}

View File

@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if m.cancel != nil { if m.cancel == nil {
return nil
}
m.cancel() m.cancel()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-m.done: case <-m.done:
return nil
}
} }
return nil return nil
@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil return nil
} }
bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
done := make(chan error, 1) done := make(chan error, 1)
start := time.Now() start := time.Now()
go func() { go func() {
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
}() }()
select { select {
@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// marshalWithPanicRecovery JSON-encodes v with json.Marshal and converts any
// panic raised during encoding into an ordinary error instead of letting it
// propagate to the caller.
//
// On success it returns the encoded bytes and a nil error. If json.Marshal
// returns an error, that error is passed through unchanged. If encoding
// panics, the returned bytes are nil and the error carries the panic value.
func marshalWithPanicRecovery(v any) (out []byte, err error) {
	// The named results let the deferred handler overwrite err after a panic;
	// out keeps its zero value because json.Marshal never returned.
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("panic during marshal: %v", recovered)
		}
	}()

	return json.Marshal(v)
}

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -14,6 +15,19 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// WriteBytesWithRestrictedPermission writes the raw bytes bs to file.
// It first runs prepareConfigFileDir and EnforcePermission for the target
// path, wrapping any failure with the name of the step that produced it, and
// then delegates the actual write to writeBytes.
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
	dir, name, err := prepareConfigFileDir(file)
	if err != nil {
		return fmt.Errorf("prepare config file dir: %w", err)
	}

	err = EnforcePermission(file)
	if err != nil {
		return fmt.Errorf("enforce permission: %w", err)
	}

	// err is nil at this point; it is forwarded only because writeBytes takes
	// an error argument.
	return writeBytes(ctx, file, err, dir, name, bs)
}
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
@ -82,29 +96,44 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations // Check context before expensive operations
if ctx.Err() != nil { if ctx.Err() != nil {
return ctx.Err() return fmt.Errorf("write json start: %w", ctx.Err())
} }
// make it pretty // make it pretty
bs, err := json.MarshalIndent(obj, "", " ") bs, err := json.MarshalIndent(obj, "", " ")
if err != nil { if err != nil {
return err return fmt.Errorf("marshal: %w", err)
} }
return writeBytes(ctx, file, err, configDir, configFileName, bs)
}
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
if ctx.Err() != nil { if ctx.Err() != nil {
return ctx.Err() return fmt.Errorf("write bytes start: %w", ctx.Err())
} }
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil { if err != nil {
return err return fmt.Errorf("create temp: %w", err)
} }
tempFileName := tempFile.Name() tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close() if deadline, ok := ctx.Deadline(); ok {
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
log.Warnf("failed to set deadline: %v", err)
}
}
_, err = tempFile.Write(bs)
if err != nil { if err != nil {
return err _ = tempFile.Close()
return fmt.Errorf("write: %w", err)
}
if err = tempFile.Close(); err != nil {
return fmt.Errorf("close %s: %w", tempFileName, err)
} }
defer func() { defer func() {
@ -114,19 +143,13 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri
} }
}() }()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
// Check context again // Check context again
if ctx.Err() != nil { if ctx.Err() != nil {
return ctx.Err() return fmt.Errorf("after temp file: %w", ctx.Err())
} }
err = os.Rename(tempFileName, file) if err = os.Rename(tempFileName, file); err != nil {
if err != nil { return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
return err
} }
return nil return nil