[client] Cleanup dns and route states on startup (#2757)

This commit is contained in:
Viktor Liu 2024-10-24 10:53:46 +02:00 committed by GitHub
parent 44f2ce666e
commit 869537c951
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 786 additions and 377 deletions

View File

@ -0,0 +1 @@
package nftables

View File

@ -117,12 +117,6 @@ func (c *ConnectClient) run(
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@ -358,7 +352,11 @@ func (c *ConnectClient) Stop() error {
if c.engine == nil { if c.engine == nil {
return nil return nil
} }
return c.engine.Stop() if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
} }
func (c *ConnectClient) isContextCancelled() bool { func (c *ConnectClient) isContextCancelled() bool {

View File

@ -1,6 +1,5 @@
package dns package dns
const ( const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
) )

View File

@ -3,6 +3,5 @@
package dns package dns
const ( const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
) )

View File

@ -9,6 +9,8 @@ import (
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
var ( var (
@ -20,7 +22,7 @@ var (
} }
) )
type repairConfFn func([]string, string, *resolvConf) error type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
type repair struct { type repair struct {
operationFile string operationFile string
@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
} }
} }
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) { func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
if f.inotify != nil { if f.inotify != nil {
return return
} }
@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
} }
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf) err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
if err != nil { if err != nil {
log.Errorf("failed to repair resolv.conf: %v", err) log.Errorf("failed to repair resolv.conf: %v", err)
} }

View File

@ -9,6 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -104,14 +105,14 @@ nameserver 8.8.8.8`,
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf) error { updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(operationFile, updateFn) r := newRepair(operationFile, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
if err != nil { if err != nil {
@ -151,14 +152,14 @@ searchdomain netbird.cloud something`
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf) error { updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(tmpLink, updateFn) r := newRepair(tmpLink, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
if err != nil { if err != nil {

View File

@ -11,6 +11,8 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -36,7 +38,7 @@ type fileConfigurator struct {
nbNameserverIP string nbNameserverIP string
} }
func newFileConfigurator() (hostManager, error) { func newFileConfigurator() (*fileConfigurator, error) {
fc := &fileConfigurator{} fc := &fileConfigurator{}
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
return fc, nil return fc, nil
@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false return false
} }
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
backupFileExist := f.isBackupFileExist() backupFileExist := f.isBackupFileExist()
if !config.RouteAll { if !config.RouteAll {
if backupFileExist { if backupFileExist {
@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
f.repair.stopWatchFileChanges() f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
if err != nil { if err != nil {
return err return err
} }
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
return nil return nil
} }
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error { func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg) nameServers := generateNsList(nbNameserverIP, cfg)
@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf // create another backup for unclean shutdown detection right after overwriting the original resolv.conf
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil { if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
} }
@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring") log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
return nil return nil
} }
@ -192,10 +190,6 @@ func restoreResolvConfFile() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
}
return nil return nil
} }

View File

@ -5,14 +5,14 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
type hostManager interface { type hostManager interface {
applyDNSConfig(config HostDNSConfig) error applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error restoreHostDNS() error
supportCustomPort() bool supportCustomPort() bool
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type SystemDNSSettings struct { type SystemDNSSettings struct {
@ -35,15 +35,15 @@ type DomainConfig struct {
} }
type mockHostConfigurator struct { type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig) error applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNSFunc func() error restoreHostDNSFunc func() error
supportCustomPortFunc func() bool supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error restoreUncleanShutdownDNSFunc func(*netip.Addr) error
} }
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if m.applyDNSConfigFunc != nil { if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config) return m.applyDNSConfigFunc(config, stateManager)
} }
return fmt.Errorf("method applyDNSSettings is not implemented") return fmt.Errorf("method applyDNSSettings is not implemented")
} }
@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false return false
} }
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
if m.restoreUncleanShutdownDNSFunc != nil {
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
}
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
}
func newNoopHostMocker() hostManager { func newNoopHostMocker() hostManager {
return &mockHostConfigurator{ return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil }, applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
restoreHostDNSFunc: func() error { return nil }, restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true }, supportCustomPortFunc: func() bool { return true },
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },

View File

@ -1,15 +1,17 @@
package dns package dns
import "net/netip" import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager() (hostManager, error) { func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
return nil return nil
} }
@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) supportCustomPort() bool {
return false return false
} }
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@ -8,12 +8,13 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -37,7 +38,7 @@ type systemConfigurator struct {
systemDNSSettings SystemDNSSettings systemDNSSettings SystemDNSSettings
} }
func newHostManager() (hostManager, error) { func newHostManager() (*systemConfigurator, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil
@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error var err error
// create a file for unclean shutdown detection if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
if err := createUncleanShutdownIndicator(); err != nil { log.Errorf("failed to update shutdown state: %s", err)
log.Errorf("failed to create unclean shutdown file: %s", err)
} }
var ( var (
@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
} }
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil return nil
} }
@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil return primaryService, router, nil
} }
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err) return fmt.Errorf("restoring dns via scutil: %w", err)
} }

View File

@ -3,9 +3,10 @@ package dns
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type iosHostManager struct { type iosHostManager struct {
@ -13,13 +14,13 @@ type iosHostManager struct {
config HostDNSConfig config HostDNSConfig
} }
func newHostManager(dnsManager IosDnsManager) (hostManager, error) { func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
return &iosHostManager{ return &iosHostManager{
dnsManager: dnsManager, dnsManager: dnsManager,
}, nil }, nil
} }
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config) jsonData, err := json.Marshal(config)
if err != nil { if err != nil {
return fmt.Errorf("marshal: %w", err) return fmt.Errorf("marshal: %w", err)
@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error {
func (a iosHostManager) supportCustomPort() bool { func (a iosHostManager) supportCustomPort() bool {
return false return false
} }
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@ -4,9 +4,9 @@ package dns
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"net/netip"
"os" "os"
"strings" "strings"
@ -21,27 +21,8 @@ const (
resolvConfManager resolvConfManager
) )
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
type osManagerType int type osManagerType int
func newOsManagerType(osManager string) (osManagerType, error) {
switch osManager {
case "netbird":
return fileManager, nil
case "file":
return netbirdManager, nil
case "networkManager":
return networkManager, nil
case "systemd":
return systemdManager, nil
case "resolvconf":
return resolvConfManager, nil
default:
return 0, ErrUnknownOsManagerType
}
}
func (t osManagerType) String() string { func (t osManagerType) String() string {
switch t { switch t {
case netbirdManager: case netbirdManager:
@ -59,6 +40,11 @@ func (t osManagerType) String() string {
} }
} }
type restoreHostManager interface {
hostManager
restoreUncleanShutdownDNS(*netip.Addr) error
}
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return newHostManagerFromType(wgInterface, osManager) return newHostManagerFromType(wgInterface, osManager)
} }
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) { func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
switch osManager { switch osManager {
case networkManager: case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface) return newNetworkManagerDbusConfigurator(wgInterface)

View File

@ -3,11 +3,12 @@ package dns
import ( import (
"fmt" "fmt"
"io" "io"
"net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@ -31,7 +32,7 @@ type registryConfigurator struct {
routingAll bool routingAll bool
} }
func newHostManager(wgInterface WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
guid, err := wgInterface.GetInterfaceGUIDString() guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil { if err != nil {
return nil, err return nil, err
@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
return newHostManagerWithGuid(guid) return newHostManagerWithGuid(guid)
} }
func newHostManagerWithGuid(guid string) (hostManager, error) { func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
return &registryConfigurator{ return &registryConfigurator{
guid: guid, guid: guid,
}, nil }, nil
@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error var err error
if config.RouteAll { if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP) err = r.addDNSSetupForAll(config.ServerIP)
@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
// create a file for unclean shutdown detection if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
if err := createUncleanShutdownIndicator(r.guid); err != nil { log.Errorf("failed to update shutdown state: %s", err)
log.Errorf("failed to create unclean shutdown file: %s", err)
} }
var ( var (
@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil return nil
} }
@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
return regKey, nil return regKey, nil
} }
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err) return fmt.Errorf("restoring dns via registry: %w", err)
} }

View File

@ -16,6 +16,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbversion "github.com/netbirdio/netbird/version" nbversion "github.com/netbirdio/netbird/version"
) )
@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{
type networkManagerDbusConfigurator struct { type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
ifaceName string
} }
// the types below are based on dbus specification, each field is mapped to a dbus type // the types below are based on dbus specification, each field is mapped to a dbus type
@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) { func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, fmt.Errorf("get nm dbus: %w", err) return nil, fmt.Errorf("get nm dbus: %w", err)
@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error)
return &networkManagerDbusConfigurator{ return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil }, nil
} }
@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
return false return false
} }
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings() connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil { if err != nil {
return fmt.Errorf("retrieving the applied connection settings, error: %w", err) return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. state := &ShutdownState{
// The file content itself is not important for network-manager restoration ManagerType: networkManager,
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil { WgIface: n.ifaceName,
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) }
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("delete connection settings: %w", err) return fmt.Errorf("delete connection settings: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return nil return nil
} }

View File

@ -9,6 +9,8 @@ import (
"os/exec" "os/exec"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@ -22,7 +24,7 @@ type resolvconf struct {
} }
// supported "openresolv" only // supported "openresolv" only
func newResolvConfConfigurator(wgInterface string) (hostManager, error) { func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
resolvConfEntries, err := parseDefaultResolvConf() resolvConfEntries, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool {
return false return false
} }
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error var err error
if !config.RouteAll { if !config.RouteAll {
err = r.restoreHostDNS() err = r.restoreHostDNS()
@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
append([]string{config.ServerIP}, r.originalNameServers...), append([]string{config.ServerIP}, r.originalNameServers...),
options) options)
// create a backup for unclean shutdown detection before the resolv.conf is changed state := &ShutdownState{
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { ManagerType: resolvConfManager,
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) WgIface: r.ifaceName,
}
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
err = r.applyConfig(buf) err = r.applyConfig(buf)
@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error {
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err) return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
} }
return nil return nil
@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
cmd.Stdin = &content cmd.Stdin = &content
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err) return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
} }
return nil return nil
} }

View File

@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
@ -14,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@ -63,6 +65,7 @@ type DefaultServer struct {
iosDnsManager IosDnsManager iosDnsManager IosDnsManager
statusRecorder *peer.Status statusRecorder *peer.Status
stateManager *statemanager.Manager
} }
type handlerWithStop interface { type handlerWithStop interface {
@ -77,12 +80,7 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@ -99,7 +97,7 @@ func NewDefaultServer(
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@ -130,12 +128,12 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
return ds return ds
} }
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
} }
@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) {
} }
} }
s.stateManager.RegisterState(&ShutdownState{})
s.hostManager, err = s.initialize() s.hostManager, err = s.initialize()
if err != nil { if err != nil {
return fmt.Errorf("initialize: %w", err) return fmt.Errorf("initialize: %w", err)
@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() {
s.ctxCancel() s.ctxCancel()
if s.hostManager != nil { if s.hostManager != nil {
err := s.hostManager.restoreHostDNS() if err := s.hostManager.restoreHostDNS(); err != nil {
if err != nil { log.Error("failed to restore host DNS settings: ", err)
log.Error(err) } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Errorf("failed to delete shutdown dns state: %v", err)
} }
} }
@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate.RouteAll = false hostUpdate.RouteAll = false
} }
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
log.Error(err) log.Error(err)
} }
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
} }
@ -521,7 +529,7 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
@ -551,7 +559,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler) s.service.RegisterMux(nbdns.RootZone, handler)
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
@ -291,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -400,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@ -495,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }
@ -554,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}), service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
@ -570,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
var domainsUpdate string var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error { hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{} domains := []string{}
for _, item := range config.Domains { for _, item := range config.Domains {
if item.Disabled { if item.Disabled {

View File

@ -1,5 +1,5 @@
package dns package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) { func (s *DefaultServer) initialize() (hostManager, error) {
return newHostManager(s.wgInterface) return newHostManager(s.wgInterface)
} }

View File

@ -7,7 +7,7 @@ import (
var errNotImplemented = errors.New("not implemented") var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { func newSystemdDbusConfigurator(string) (restoreHostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
} }

View File

@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@ -38,6 +39,7 @@ const (
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
ifaceName string
} }
// the types below are based on dbus specification, each field is mapped to a dbus type // the types below are based on dbus specification, each field is mapped to a dbus type
@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) {
iface, err := net.InterfaceByName(wgInterface) iface, err := net.InterfaceByName(wgInterface)
if err != nil { if err != nil {
return nil, fmt.Errorf("get interface: %w", err) return nil, fmt.Errorf("get interface: %w", err)
@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return &systemdDbusConfigurator{ return &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil }, nil
} }
@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
parsedIP, err := netip.ParseAddr(config.ServerIP) parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %w", err) return fmt.Errorf("unable to parse ip address, error: %w", err)
@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} }
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. state := &ShutdownState{
// The file content itself is not important for systemd restoration ManagerType: systemdManager,
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil { WgIface: s.ifaceName,
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) }
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("unable to revert link configuration, got error: %w", err) return fmt.Errorf("unable to revert link configuration, got error: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return s.flushCaches() return s.flushCaches()
} }

View File

@ -1,5 +0,0 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@ -3,57 +3,25 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
) )
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns" type ShutdownState struct {
}
func CheckUncleanShutdown(string) error { func (s *ShutdownState) Name() string {
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil { return "dns_state"
if errors.Is(err, fs.ErrNotExist) { }
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
func (s *ShutdownState) Cleanup() error {
manager, err := newHostManager() manager, err := newHostManager()
if err != nil { if err != nil {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(nil); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator() error {
dir := filepath.Dir(fileUncleanShutdownFileLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}

View File

@ -1,5 +0,0 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@ -0,0 +1,14 @@
//go:build ios || android
package dns
type ShutdownState struct {
}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}

View File

@ -3,66 +3,44 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"strings"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
func CheckUncleanShutdown(wgIface string) error { type ShutdownState struct {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { ManagerType osManagerType
if errors.Is(err, fs.ErrNotExist) { DNSAddress netip.Addr
// no file -> clean shutdown WgIface string
return nil }
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation) func (s *ShutdownState) Name() string {
return "dns_state"
}
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation) func (s *ShutdownState) Cleanup() error {
if err != nil { manager, err := newHostManagerFromType(s.WgIface, s.ManagerType)
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
managerFields := strings.Split(string(managerData), ",")
if len(managerFields) < 2 {
return errors.New("split manager data: insufficient number of fields")
}
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
}
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
// determine os manager type, so we can invoke the respective restore action
osManagerType, err := newOsManagerType(osManagerTypeStr)
if err != nil {
return fmt.Errorf("detect previous host manager: %w", err)
}
manager, err := newHostManagerFromType(wgIface, osManagerType)
if err != nil { if err != nil {
return fmt.Errorf("create previous host manager: %w", err) return fmt.Errorf("create previous host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil { if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error { // TODO: move file contents to state manager
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error {
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
}
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err) return fmt.Errorf("create dir %s: %w", dir, err)
@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType
return fmt.Errorf("create %s: %w", sourcePath, err) return fmt.Errorf("create %s: %w", sourcePath, err)
} }
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress) state := &ShutdownState{
ManagerType: fileManager,
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec DNSAddress: dnsAddress,
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
}
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
} }
if err := stateManager.UpdateState(state); err != nil {
return fmt.Errorf("update state: %w", err)
}
return nil return nil
} }

View File

@ -1,75 +1,26 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
) )
const ( type ShutdownState struct {
netbirdProgramDataLocation = "Netbird" Guid string
fileUncleanShutdownFile = "unclean_shutdown_dns.txt" }
)
func CheckUncleanShutdown(string) error { func (s *ShutdownState) Name() string {
file := getUncleanShutdownFile() return "dns_state"
}
if _, err := os.Stat(file); err != nil { func (s *ShutdownState) Cleanup() error {
if errors.Is(err, fs.ErrNotExist) { manager, err := newHostManagerWithGuid(s.Guid)
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
guid, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("read %s: %w", file, err)
}
manager, err := newHostManagerWithGuid(string(guid))
if err != nil { if err != nil {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(nil); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator(guid string) error {
file := getUncleanShutdownFile()
dir := filepath.Dir(file)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
return fmt.Errorf("create %s: %w", file, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
file := getUncleanShutdownFile()
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", file, err)
}
return nil
}
func getUncleanShutdownFile() string {
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
}

View File

@ -23,18 +23,19 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -166,6 +167,7 @@ type Engine struct {
checks []*mgmProto.Checks checks []*mgmProto.Checks
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -213,7 +215,7 @@ func NewEngineWithProbes(
probes *ProbeHolder, probes *ProbeHolder,
checks []*mgmProto.Checks, checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return &Engine{ engine := &Engine{
clientCtx: clientCtx, clientCtx: clientCtx,
clientCancel: clientCancel, clientCancel: clientCancel,
signal: signalClient, signal: signalClient,
@ -232,6 +234,11 @@ func NewEngineWithProbes(
probes: probes, probes: probes,
checks: checks, checks: checks,
} }
if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path)
}
return engine
} }
func (e *Engine) Stop() error { func (e *Engine) Stop() error {
@ -253,7 +260,7 @@ func (e *Engine) Stop() error {
e.stopDNSServer() e.stopDNSServer()
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop() e.routeManager.Stop(e.stateManager)
} }
err := e.removeAllPeers() err := e.removeAllPeers()
@ -275,6 +282,17 @@ func (e *Engine) Stop() error {
e.close() e.close()
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(ctx); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil return nil
} }
@ -314,6 +332,8 @@ func (e *Engine) Start() error {
} }
} }
e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil { if err != nil {
e.close() e.close()
@ -322,7 +342,7 @@ func (e *Engine) Start() error {
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else { } else {
@ -1219,10 +1239,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
return nil, dnsServer, nil return nil, dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return nil, dnsServer, nil return nil, dnsServer, nil
} }
} }

View File

@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@ -31,14 +32,14 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error EnableServerRouter(firewall firewall.Manager) error
Stop() Stop(stateManager *statemanager.Manager)
} }
// DefaultManager is the default instance of a route manager // DefaultManager is the default instance of a route manager
@ -120,12 +121,12 @@ func NewManager(
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() { if nbnet.CustomRoutingDisabled() {
return nil, nil, nil return nil, nil, nil
} }
if err := m.sysOps.CleanupRouting(); err != nil { if err := m.sysOps.CleanupRouting(nil); err != nil {
log.Warnf("Failed cleaning up routing: %v", err) log.Warnf("Failed cleaning up routing: %v", err)
} }
@ -136,7 +137,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err) return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
@ -154,7 +155,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
} }
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop() m.stop()
if m.serverRouter != nil { if m.serverRouter != nil {
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
@ -172,7 +173,7 @@ func (m *DefaultManager) Stop() {
} }
if !nbnet.CustomRoutingDisabled() { if !nbnet.CustomRoutingDisabled() {
if err := m.sysOps.CleanupRouting(); err != nil { if err := m.sysOps.CleanupRouting(stateManager); err != nil {
log.Errorf("Error cleaning up routing: %v", err) log.Errorf("Error cleaning up routing: %v", err)
} else { } else {
log.Info("Routing cleanup complete") log.Info("Routing cleanup complete")

View File

@ -426,10 +426,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
_, _, err = routeManager.Init() _, _, err = routeManager.Init(nil)
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop() defer routeManager.Stop(nil)
if testCase.removeSrvRouter { if testCase.removeSrvRouter {
routeManager.serverRouter = nil routeManager.serverRouter = nil

View File

@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/util/net"
) )
@ -17,10 +18,10 @@ type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func() StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }
@ -65,8 +66,8 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
} }
// Stop mock implementation of Stop from Manager interface // Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop() { func (m *MockManager) Stop(stateManager *statemanager.Manager) {
if m.StopFunc != nil { if m.StopFunc != nil {
m.StopFunc() m.StopFunc(stateManager)
} }
} }

View File

@ -0,0 +1,81 @@
package systemops
import (
"encoding/json"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
)
type RouteEntry struct {
Prefix netip.Prefix `json:"prefix"`
Nexthop Nexthop `json:"nexthop"`
}
type ShutdownState struct {
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"`
mu sync.RWMutex
}
func NewShutdownState() *ShutdownState {
return &ShutdownState{
Routes: make(map[netip.Prefix]RouteEntry),
}
}
func (s *ShutdownState) Name() string {
return "route_state"
}
func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil)
var merr *multierror.Error
s.mu.RLock()
defer s.mu.RUnlock()
for _, route := range s.Routes {
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
s.mu.Lock()
defer s.mu.Unlock()
s.Routes[prefix] = RouteEntry{
Prefix: prefix,
Nexthop: nexthop,
}
}
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.Routes, prefix)
}
// MarshalJSON ensures that empty routes are marshaled as null
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if len(s.Routes) == 0 {
return json.Marshal(nil)
}
return json.Marshal(s.Routes)
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &s.Routes)
}

View File

@ -9,14 +9,15 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return nil, nil, nil return nil, nil, nil
} }
func (r *SysOps) CleanupRouting() error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
return nil return nil
} }
@ -28,6 +29,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
return nil return nil
} }
func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
return nil
}
func EnableIPForwarding() error { func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate") var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err) log.Errorf("Unable to get initial v4 default next hop: %v", err)
@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
// These errors are not critical, but also we should not track and try to remove the routes either. // These errors are not critical, but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
r.updateState(stateManager, prefix, nexthop)
return nexthop, err return nexthop, err
}, },
r.removeFromRouteTable, func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.removeFromState(stateManager, prefix)
return r.removeFromRouteTable(prefix, nexthop)
},
) )
r.refCounter = refCounter r.refCounter = refCounter
@ -63,7 +75,25 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
return r.setupHooks(initAddresses) return r.setupHooks(initAddresses)
} }
func (r *SysOps) cleanupRefCounter() error { func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) {
state := getState(stateManager)
state.UpdateRoute(prefix, nexthop)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
state := getState(stateManager)
state.RemoveRoute(prefix)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("Failed to update state: %v", err)
}
}
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
if r.refCounter == nil { if r.refCounter == nil {
return nil return nil
} }
@ -76,6 +106,10 @@ func (r *SysOps) cleanupRefCounter() error {
return fmt.Errorf("flush route manager: %w", err) return fmt.Errorf("flush route manager: %w", err)
} }
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Errorf("failed to delete state: %v", err)
}
return nil return nil
} }
@ -506,3 +540,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes // Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix return isVpn, longestPrefix
} }
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = NewShutdownState()
}
return shutdownState
}

View File

@ -77,10 +77,10 @@ func TestAddRemoveRoutes(t *testing.T) {
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil) _, _, err = r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting()) assert.NoError(t, r.CleanupRouting(nil))
}) })
index, err := net.InterfaceByName(wgInterface.Name()) index, err := net.InterfaceByName(wgInterface.Name())
@ -403,10 +403,10 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil) _, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting()) assert.NoError(t, r.CleanupRouting(nil))
}) })
index, err := net.InterfaceByName(wgInterface.Name()) index, err := net.InterfaceByName(wgInterface.Name())

View File

@ -9,17 +9,18 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{}) r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil return nil, nil, nil
} }
func (r *SysOps) CleanupRouting() error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
@ -46,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
return nil return nil
} }
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}
func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
return nil
}
func EnableIPForwarding() error { func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
@ -54,11 +67,3 @@ func EnableIPForwarding() error {
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{} return false, netip.Prefix{}
} }
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}

View File

@ -18,6 +18,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -85,10 +86,10 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
if isLegacy() { if isLegacy() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses, stateManager)
} }
if err = addRoutingTableName(); err != nil { if err = addRoutingTableName(); err != nil {
@ -104,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := r.CleanupRouting(); cleanErr != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr) log.Errorf("Error cleaning up routing: %v", cleanErr)
} }
} }
@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
if errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true) setIsLegacy(true)
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses, stateManager)
} }
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
} }
@ -128,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state. // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process. // The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting() error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
if isLegacy() { if isLegacy() {
return r.cleanupRefCounter() return r.cleanupRefCounter(stateManager)
} }
var result *multierror.Error var result *multierror.Error

View File

@ -13,15 +13,16 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses, stateManager)
} }
func (r *SysOps) CleanupRouting() error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
return r.cleanupRefCounter() return r.cleanupRefCounter(stateManager)
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {

View File

@ -22,6 +22,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -130,12 +131,12 @@ const (
RouteDeleted RouteDeleted
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses) return r.setupRefCounter(initAddresses, stateManager)
} }
func (r *SysOps) CleanupRouting() error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
return r.cleanupRefCounter() return r.cleanupRefCounter(stateManager)
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {

View File

@ -0,0 +1,298 @@
package statemanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"reflect"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// State interface defines the methods that all state types must implement
type State interface {
Name() string
Cleanup() error
}
// Manager handles the persistence and management of various states
type Manager struct {
mu sync.Mutex
cancel context.CancelFunc
done chan struct{}
filePath string
// holds the states that are registered with the manager and that are to be persisted
states map[string]State
// holds the state names that have been updated and need to be persisted with the next save
dirty map[string]struct{}
// holds the type information for each registered state
stateTypes map[string]reflect.Type
}
// New creates a new Manager instance
func New(filePath string) *Manager {
return &Manager{
filePath: filePath,
states: make(map[string]State),
dirty: make(map[string]struct{}),
stateTypes: make(map[string]reflect.Type),
}
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
var ctx context.Context
ctx, m.cancel = context.WithCancel(context.Background())
m.done = make(chan struct{})
go m.periodicStateSave(ctx)
}
func (m *Manager) Stop(ctx context.Context) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
m.cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
return nil
}
// RegisterState registers a state with the manager but doesn't attempt to persist it.
// Pass an uninitialized state to register it.
func (m *Manager) RegisterState(state State) {
if m == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
name := state.Name()
m.states[name] = nil
m.stateTypes[name] = reflect.TypeOf(state).Elem()
}
// GetState returns the state for the given type
func (m *Manager) GetState(state State) State {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return m.states[state.Name()]
}
// UpdateState updates the state in the manager and marks it as dirty for the next save.
// The state will be replaced with the new one.
func (m *Manager) UpdateState(state State) error {
if m == nil {
return nil
}
return m.setState(state.Name(), state)
}
// DeleteState removes the state from the manager and marks it as dirty for the next save.
// Pass an uninitialized state to delete it.
func (m *Manager) DeleteState(state State) error {
if m == nil {
return nil
}
return m.setState(state.Name(), nil)
}
func (m *Manager) setState(name string, state State) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.states[name]; !exists {
return fmt.Errorf("state %s not registered", name)
}
m.states[name] = state
m.dirty[name] = struct{}{}
return nil
}
func (m *Manager) periodicStateSave(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
defer close(m.done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := m.PersistState(ctx); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}
}
}
// PersistState persists the states that have been updated since the last save.
func (m *Manager) PersistState(ctx context.Context) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
if len(m.dirty) == 0 {
return nil
}
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
data, err := json.MarshalIndent(m.states, "", " ")
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
if err != nil {
return err
}
}
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
clear(m.dirty)
return nil
}
// loadState loads the existing state from the state file
func (m *Manager) loadState() error {
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
return nil
}
return fmt.Errorf("read state file: %w", err)
}
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
log.Warn("State file appears to be corrupted, attempting to delete it")
if err := os.Remove(m.filePath); err != nil {
log.Errorf("Failed to delete corrupted state file: %v", err)
} else {
log.Info("State file deleted")
}
return fmt.Errorf("unmarshal states: %w", err)
}
var merr *multierror.Error
for name, rawState := range rawStates {
stateType, ok := m.stateTypes[name]
if !ok {
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
continue
}
if string(rawState) == "null" {
continue
}
statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
continue
}
m.states[name] = statePtr
log.Debugf("loaded state: %s", name)
}
return nberrors.FormatErrorOrNil(merr)
}
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
// If the cleanup is successful, the state is marked for deletion.
func (m *Manager) PerformCleanup() error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
if err := m.loadState(); err != nil {
log.Warnf("Failed to load state during cleanup: %v", err)
}
var merr *multierror.Error
for name, state := range m.states {
if state == nil {
// If no state was found in the state file, we don't mark the state dirty nor return an error
continue
}
log.Infof("client was not shut down properly, cleaning up %s", name)
if err := state.Cleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
} else {
// mark for deletion on cleanup success
m.states[name] = nil
m.dirty[name] = struct{}{}
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@ -0,0 +1,35 @@
package statemanager
import (
"os"
"path/filepath"
"runtime"
"github.com/sirupsen/logrus"
)
// GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
func GetDefaultStatePath() string {
var path string
switch runtime.GOOS {
case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
path = "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return path
}

View File

@ -138,12 +138,12 @@ func (c *Client) Stop() {
c.ctxCancel() c.ctxCancel()
} }
// ÏSetTraceLogLevel configure the logger to trace level // SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() { func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
} }
// getStatusDetails return with the list of the PeerInfos // GetStatusDetails return with the list of the PeerInfos
func (c *Client) GetStatusDetails() *StatusDetails { func (c *Client) GetStatusDetails() *StatusDetails {
fullStatus := c.recorder.GetFullStatus() fullStatus := c.recorder.GetFullStatus()

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
@ -20,7 +21,11 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
@ -39,6 +44,8 @@ const (
defaultMaxRetryInterval = 60 * time.Minute defaultMaxRetryInterval = 60 * time.Minute
defaultMaxRetryTime = 14 * 24 * time.Hour defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7 defaultRetryMultiplier = 1.7
errRestoreResidualState = "failed to restore residual state: %v"
) )
// Server for service control. // Server for service control.
@ -95,6 +102,10 @@ func (s *Server) Start() error {
defer s.mutex.Unlock() defer s.mutex.Unlock()
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
if err := restoreResidualState(s.rootCtx); err != nil {
log.Warnf(errRestoreResidualState, err)
}
// if current state contains any error, return it // if current state contains any error, return it
// in all other cases we can continue execution only if status is idle and up command was // in all other cases we can continue execution only if status is idle and up command was
// not in the progress or already successfully established connection. // not in the progress or already successfully established connection.
@ -292,6 +303,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.actCancel = cancel s.actCancel = cancel
s.mutex.Unlock() s.mutex.Unlock()
if err := restoreResidualState(ctx); err != nil {
log.Warnf(errRestoreResidualState, err)
}
state := internal.CtxGetState(ctx) state := internal.CtxGetState(ctx)
defer func() { defer func() {
status, err := state.Status() status, err := state.Status()
@ -549,6 +564,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx); err != nil {
log.Warnf(errRestoreResidualState, err)
}
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
// if current state contains any error, return it // if current state contains any error, return it
@ -829,3 +848,31 @@ func sendTerminalNotification() error {
return wallCmd.Wait() return wallCmd.Wait()
} }
// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
return nil
}
mgr := statemanager.New(path)
var merr *multierror.Error
// register the states we are interested in restoring
// this will also allow each subsystem to record its own state
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
if err := mgr.PerformCleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
}
if err := mgr.PersistState(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}