From 39329e12a1f54cc1770a010279fa4902ccd51693 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:46:00 +0100 Subject: [PATCH] [client] Improve state write timeout and abort work early on timeout (#2882) * Improve state write timeout and abort work early on timeout * Don't block on initial persist state --- client/firewall/iptables/manager_linux.go | 8 +++++--- client/firewall/nftables/manager_linux.go | 8 +++++--- client/internal/config.go | 6 +++--- client/internal/dns/server.go | 10 +++++++--- client/internal/statemanager/manager.go | 20 +++++--------------- client/internal/statemanager/path.go | 22 +++++----------------- management/server/file_store.go | 2 +- util/file.go | 23 ++++++++++++++++++----- util/file_test.go | 7 ++++--- 9 files changed, 53 insertions(+), 53 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a59bd2c60..adb8f20ef 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early to ensure cleanup of chains - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index ea8912f27..3f8fac249 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/internal/config.go b/client/internal/config.go index ee54c6380..ce87835cd 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if err != nil { return nil, err } - err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) return cfg, err } @@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) { // WriteOutConfig write put the prepared config to the given path func WriteOutConfig(path string, config *Config) error { - return util.WriteJson(path, config) + return util.WriteJson(context.Background(), path, config) } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) { } if updated { - if err := util.WriteJson(input.ConfigPath, config); err != nil { + if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { return nil, err } } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 929e1e60c..6c4dccae7 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // persist dns state right away ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - log.Errorf("Failed to persist dns state: %v", err) - } + + // don't block + go func() { + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + }() if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index a5a14f807..580ccdfc7 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -16,6 +16,7 @@ import ( "golang.org/x/exp/maps" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/util" ) // State interface defines the methods that all state types must implement @@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) + start := time.Now() go func() { - data, err := json.MarshalIndent(m.states, "", " ") - if err != nil { - done <- fmt.Errorf("marshal states: %w", err) - return - } - - // nolint:gosec - if err := os.WriteFile(m.filePath, data, 0640); err != nil { - done <- fmt.Errorf("write state file: %w", err) - return - } - - done <- nil + done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) }() select { @@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } } - log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 96d6a9f12..6cfd79a12 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -4,32 +4,20 @@ import ( "os" "path/filepath" "runtime" - - log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system -// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +// It returns an empty string if the path cannot be determined. func GetDefaultStatePath() string { - var path string - switch runtime.GOOS { case "windows": - path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") case "darwin", "linux": - path = "/var/lib/netbird/state.json" + return "/var/lib/netbird/state.json" case "freebsd", "openbsd", "netbsd", "dragonfly": - path = "/var/db/netbird/state.json" - // ios/android don't need state - default: - return "" + return "/var/db/netbird/state.json" } - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { - log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) - return "" - } + return "" - return path } diff --git a/management/server/file_store.go b/management/server/file_store.go index 561e133ce..f375fb990 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { // It is recommended to call it with locking FileStore.mux func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() - err := util.WriteJson(file, s) + err := util.WriteJson(context.Background(), file, s) if err != nil { return err } diff --git a/util/file.go b/util/file.go index ecaecd222..4641cc1b8 100644 --- a/util/file.go +++ b/util/file.go @@ -15,7 +15,7 @@ import ( ) // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory -func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { +func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err @@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // WriteJson writes JSON config object to a file creating parent directories if required // The output JSON is pretty-formatted -func WriteJson(file string, obj interface{}) error { +func WriteJson(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } -func writeJson(file string, obj interface{}, configDir string, configFileName string) error { +func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { + // Check context before expensive operations + if ctx.Err() != nil { + return ctx.Err() + } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") @@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + if ctx.Err() != nil { + return ctx.Err() + } + tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { return err @@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + // Check context again + if ctx.Err() != nil { + return ctx.Err() + } + err = os.Rename(tempFileName, file) if err != nil { return err diff --git a/util/file_test.go b/util/file_test.go index 566d8eda6..f8c9dfabb 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,6 +1,7 @@ package util import ( + "context" "crypto/md5" "encoding/hex" "io" @@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tmpDir := t.TempDir() - err := WriteJson(tmpDir+"/testconfig.json", tt.config) + err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config) require.NoError(t, err) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) @@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) { src := tmpDir + "/copytest_src" dst := tmpDir + "/copytest_dst" - err := WriteJson(src, tt.srcContent) + err := WriteJson(context.Background(), src, tt.srcContent) require.NoError(t, err) err = CopyFileContents(src, dst) @@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) { _ = os.Remove(cfgFile) }() - err := WriteJson(cfgFile, tt.config) + err := WriteJson(context.Background(), cfgFile, tt.config) require.NoError(t, err) read, err := ReadJson(cfgFile, &TestConfig{})