mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
[client] Cleanup dns and route states on startup (#2757)
This commit is contained in:
parent
44f2ce666e
commit
869537c951
1
client/firewall/nftables/state.go
Normal file
1
client/firewall/nftables/state.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package nftables
|
@ -117,12 +117,6 @@ func (c *ConnectClient) run(
|
|||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
// Check if client was not shut down in a clean way and restore DNS config if required.
|
|
||||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
|
||||||
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
|
|
||||||
log.Errorf("checking unclean shutdown error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@ -358,7 +352,11 @@ func (c *ConnectClient) Stop() error {
|
|||||||
if c.engine == nil {
|
if c.engine == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.engine.Stop()
|
if err := c.engine.Stop(); err != nil {
|
||||||
|
return fmt.Errorf("stop engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) isContextCancelled() bool {
|
func (c *ConnectClient) isContextCancelled() bool {
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
|
||||||
)
|
)
|
||||||
|
@ -3,6 +3,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
const (
|
const (
|
||||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
|
||||||
)
|
)
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -20,7 +22,7 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type repairConfFn func([]string, string, *resolvConf) error
|
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
|
||||||
|
|
||||||
type repair struct {
|
type repair struct {
|
||||||
operationFile string
|
operationFile string
|
||||||
@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
|
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
|
||||||
if f.inotify != nil {
|
if f.inotify != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
|
|||||||
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
|
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to repair resolv.conf: %v", err)
|
log.Errorf("failed to repair resolv.conf: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -104,14 +105,14 @@ nameserver 8.8.8.8`,
|
|||||||
|
|
||||||
var changed bool
|
var changed bool
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
updateFn := func([]string, string, *resolvConf) error {
|
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
|
||||||
changed = true
|
changed = true
|
||||||
cancel()
|
cancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r := newRepair(operationFile, updateFn)
|
r := newRepair(operationFile, updateFn)
|
||||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
|
||||||
|
|
||||||
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -151,14 +152,14 @@ searchdomain netbird.cloud something`
|
|||||||
|
|
||||||
var changed bool
|
var changed bool
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
updateFn := func([]string, string, *resolvConf) error {
|
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
|
||||||
changed = true
|
changed = true
|
||||||
cancel()
|
cancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r := newRepair(tmpLink, updateFn)
|
r := newRepair(tmpLink, updateFn)
|
||||||
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
|
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
|
||||||
|
|
||||||
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -11,6 +11,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -36,7 +38,7 @@ type fileConfigurator struct {
|
|||||||
nbNameserverIP string
|
nbNameserverIP string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFileConfigurator() (hostManager, error) {
|
func newFileConfigurator() (*fileConfigurator, error) {
|
||||||
fc := &fileConfigurator{}
|
fc := &fileConfigurator{}
|
||||||
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
|
||||||
return fc, nil
|
return fc, nil
|
||||||
@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
backupFileExist := f.isBackupFileExist()
|
backupFileExist := f.isBackupFileExist()
|
||||||
if !config.RouteAll {
|
if !config.RouteAll {
|
||||||
if backupFileExist {
|
if backupFileExist {
|
||||||
@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
|
|
||||||
f.repair.stopWatchFileChanges()
|
f.repair.stopWatchFileChanges()
|
||||||
|
|
||||||
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
|
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
|
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
|
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
|
||||||
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
|
||||||
nameServers := generateNsList(nbNameserverIP, cfg)
|
nameServers := generateNsList(nbNameserverIP, cfg)
|
||||||
|
|
||||||
@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
|
|||||||
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
|
||||||
|
|
||||||
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
|
||||||
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
|
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error {
|
|||||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
|||||||
return restoreResolvConfFile()
|
return restoreResolvConfFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
|
log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,10 +190,6 @@ func restoreResolvConfFile() error {
|
|||||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,14 +5,14 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
applyDNSConfig(config HostDNSConfig) error
|
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNS() error
|
restoreHostDNS() error
|
||||||
supportCustomPort() bool
|
supportCustomPort() bool
|
||||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemDNSSettings struct {
|
type SystemDNSSettings struct {
|
||||||
@ -35,15 +35,15 @@ type DomainConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockHostConfigurator struct {
|
type mockHostConfigurator struct {
|
||||||
applyDNSConfigFunc func(config HostDNSConfig) error
|
applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||||
restoreHostDNSFunc func() error
|
restoreHostDNSFunc func() error
|
||||||
supportCustomPortFunc func() bool
|
supportCustomPortFunc func() bool
|
||||||
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
if m.applyDNSConfigFunc != nil {
|
if m.applyDNSConfigFunc != nil {
|
||||||
return m.applyDNSConfigFunc(config)
|
return m.applyDNSConfigFunc(config, stateManager)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("method applyDNSSettings is not implemented")
|
return fmt.Errorf("method applyDNSSettings is not implemented")
|
||||||
}
|
}
|
||||||
@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
|
||||||
if m.restoreUncleanShutdownDNSFunc != nil {
|
|
||||||
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newNoopHostMocker() hostManager {
|
func newNoopHostMocker() hostManager {
|
||||||
return &mockHostConfigurator{
|
return &mockHostConfigurator{
|
||||||
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
|
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
|
||||||
restoreHostDNSFunc: func() error { return nil },
|
restoreHostDNSFunc: func() error { return nil },
|
||||||
supportCustomPortFunc: func() bool { return true },
|
supportCustomPortFunc: func() bool { return true },
|
||||||
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import "net/netip"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
type androidHostManager struct {
|
type androidHostManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (hostManager, error) {
|
func newHostManager() (*androidHostManager, error) {
|
||||||
return &androidHostManager{}, nil
|
return &androidHostManager{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
|
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error {
|
|||||||
func (a androidHostManager) supportCustomPort() bool {
|
func (a androidHostManager) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -8,12 +8,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -37,7 +38,7 @@ type systemConfigurator struct {
|
|||||||
systemDNSSettings SystemDNSSettings
|
systemDNSSettings SystemDNSSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (hostManager, error) {
|
func newHostManager() (*systemConfigurator, error) {
|
||||||
return &systemConfigurator{
|
return &systemConfigurator{
|
||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}, nil
|
}, nil
|
||||||
@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// create a file for unclean shutdown detection
|
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
||||||
if err := createUncleanShutdownIndicator(); err != nil {
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
|||||||
return primaryService, router, nil
|
return primaryService, router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
if err := s.restoreHostDNS(); err != nil {
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,10 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type iosHostManager struct {
|
type iosHostManager struct {
|
||||||
@ -13,13 +14,13 @@ type iosHostManager struct {
|
|||||||
config HostDNSConfig
|
config HostDNSConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
|
func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
||||||
return &iosHostManager{
|
return &iosHostManager{
|
||||||
dnsManager: dnsManager,
|
dnsManager: dnsManager,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
|
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
jsonData, err := json.Marshal(config)
|
jsonData, err := json.Marshal(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal: %w", err)
|
return fmt.Errorf("marshal: %w", err)
|
||||||
@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error {
|
|||||||
func (a iosHostManager) supportCustomPort() bool {
|
func (a iosHostManager) supportCustomPort() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -4,9 +4,9 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -21,27 +21,8 @@ const (
|
|||||||
resolvConfManager
|
resolvConfManager
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
|
|
||||||
|
|
||||||
type osManagerType int
|
type osManagerType int
|
||||||
|
|
||||||
func newOsManagerType(osManager string) (osManagerType, error) {
|
|
||||||
switch osManager {
|
|
||||||
case "netbird":
|
|
||||||
return fileManager, nil
|
|
||||||
case "file":
|
|
||||||
return netbirdManager, nil
|
|
||||||
case "networkManager":
|
|
||||||
return networkManager, nil
|
|
||||||
case "systemd":
|
|
||||||
return systemdManager, nil
|
|
||||||
case "resolvconf":
|
|
||||||
return resolvConfManager, nil
|
|
||||||
default:
|
|
||||||
return 0, ErrUnknownOsManagerType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t osManagerType) String() string {
|
func (t osManagerType) String() string {
|
||||||
switch t {
|
switch t {
|
||||||
case netbirdManager:
|
case netbirdManager:
|
||||||
@ -59,6 +40,11 @@ func (t osManagerType) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type restoreHostManager interface {
|
||||||
|
hostManager
|
||||||
|
restoreUncleanShutdownDNS(*netip.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
|
|||||||
return newHostManagerFromType(wgInterface, osManager)
|
return newHostManagerFromType(wgInterface, osManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
|
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
|
||||||
switch osManager {
|
switch osManager {
|
||||||
case networkManager:
|
case networkManager:
|
||||||
return newNetworkManagerDbusConfigurator(wgInterface)
|
return newNetworkManagerDbusConfigurator(wgInterface)
|
||||||
|
@ -3,11 +3,12 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -31,7 +32,7 @@ type registryConfigurator struct {
|
|||||||
routingAll bool
|
routingAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
|
|||||||
return newHostManagerWithGuid(guid)
|
return newHostManagerWithGuid(guid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManagerWithGuid(guid string) (hostManager, error) {
|
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
|
||||||
return ®istryConfigurator{
|
return ®istryConfigurator{
|
||||||
guid: guid,
|
guid: guid,
|
||||||
}, nil
|
}, nil
|
||||||
@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
err = r.addDNSSetupForAll(config.ServerIP)
|
err = r.addDNSSetupForAll(config.ServerIP)
|
||||||
@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a file for unclean shutdown detection
|
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
|
||||||
if err := createUncleanShutdownIndicator(r.guid); err != nil {
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown file: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
|
|||||||
return regKey, nil
|
return regKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
|
||||||
if err := r.restoreHostDNS(); err != nil {
|
if err := r.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via registry: %w", err)
|
return fmt.Errorf("restoring dns via registry: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbversion "github.com/netbirdio/netbird/version"
|
nbversion "github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{
|
|||||||
type networkManagerDbusConfigurator struct {
|
type networkManagerDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||||
@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
|
func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) {
|
||||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get nm dbus: %w", err)
|
return nil, fmt.Errorf("get nm dbus: %w", err)
|
||||||
@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error)
|
|||||||
|
|
||||||
return &networkManagerDbusConfigurator{
|
return &networkManagerDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
|
ifaceName: wgInterface,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
connSettings, configVersion, err := n.getAppliedConnectionSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
|
||||||
@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
|
|||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
|
||||||
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
|
||||||
|
|
||||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
state := &ShutdownState{
|
||||||
// The file content itself is not important for network-manager restoration
|
ManagerType: networkManager,
|
||||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
|
WgIface: n.ifaceName,
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||||
@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("delete connection settings: %w", err)
|
return fmt.Errorf("delete connection settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const resolvconfCommand = "resolvconf"
|
const resolvconfCommand = "resolvconf"
|
||||||
@ -22,7 +24,7 @@ type resolvconf struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// supported "openresolv" only
|
// supported "openresolv" only
|
||||||
func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
|
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
|
||||||
resolvConfEntries, err := parseDefaultResolvConf()
|
resolvConfEntries, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
|
||||||
@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
if !config.RouteAll {
|
if !config.RouteAll {
|
||||||
err = r.restoreHostDNS()
|
err = r.restoreHostDNS()
|
||||||
@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
append([]string{config.ServerIP}, r.originalNameServers...),
|
append([]string{config.ServerIP}, r.originalNameServers...),
|
||||||
options)
|
options)
|
||||||
|
|
||||||
// create a backup for unclean shutdown detection before the resolv.conf is changed
|
state := &ShutdownState{
|
||||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
|
ManagerType: resolvConfManager,
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
WgIface: r.ifaceName,
|
||||||
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.applyConfig(buf)
|
err = r.applyConfig(buf)
|
||||||
@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error {
|
|||||||
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||||
}
|
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
|||||||
cmd.Stdin = &content
|
cmd.Stdin = &content
|
||||||
_, err := cmd.Output()
|
_, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
|
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
@ -14,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,6 +65,7 @@ type DefaultServer struct {
|
|||||||
iosDnsManager IosDnsManager
|
iosDnsManager IosDnsManager
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
@ -77,12 +80,7 @@ type muxUpdate struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
|
||||||
ctx context.Context,
|
|
||||||
wgInterface WGIface,
|
|
||||||
customAddress string,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
) (*DefaultServer, error) {
|
|
||||||
var addrPort *netip.AddrPort
|
var addrPort *netip.AddrPort
|
||||||
if customAddress != "" {
|
if customAddress != "" {
|
||||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||||
@ -99,7 +97,7 @@ func NewDefaultServer(
|
|||||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
|
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||||
@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@ -130,12 +128,12 @@ func NewDefaultServerIos(
|
|||||||
iosDnsManager IosDnsManager,
|
iosDnsManager IosDnsManager,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||||
ds.iosDnsManager = iosDnsManager
|
ds.iosDnsManager = iosDnsManager
|
||||||
return ds
|
return ds
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
|
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
|
|||||||
},
|
},
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
stateManager: stateManager,
|
||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.stateManager.RegisterState(&ShutdownState{})
|
||||||
s.hostManager, err = s.initialize()
|
s.hostManager, err = s.initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initialize: %w", err)
|
return fmt.Errorf("initialize: %w", err)
|
||||||
@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() {
|
|||||||
s.ctxCancel()
|
s.ctxCancel()
|
||||||
|
|
||||||
if s.hostManager != nil {
|
if s.hostManager != nil {
|
||||||
err := s.hostManager.restoreHostDNS()
|
if err := s.hostManager.restoreHostDNS(); err != nil {
|
||||||
if err != nil {
|
log.Error("failed to restore host DNS settings: ", err)
|
||||||
log.Error(err)
|
} else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
log.Errorf("failed to delete shutdown dns state: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
hostUpdate.RouteAll = false
|
hostUpdate.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
|
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// persist dns state right away
|
||||||
|
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||||
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if s.searchDomainNotifier != nil {
|
if s.searchDomainNotifier != nil {
|
||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||||
}
|
}
|
||||||
@ -521,7 +529,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -551,7 +559,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.currentConfig.RouteAll = true
|
s.currentConfig.RouteAll = true
|
||||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||||
}
|
}
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
|
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
@ -291,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
t.Log(err)
|
t.Log(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -400,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
|
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create DNS server: %v", err)
|
t.Errorf("create DNS server: %v", err)
|
||||||
return
|
return
|
||||||
@ -495,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
|
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%v", err)
|
t.Fatalf("%v", err)
|
||||||
}
|
}
|
||||||
@ -554,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
service: NewServiceViaMemory(&mocWGIface{}),
|
service: NewServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
@ -570,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var domainsUpdate string
|
var domainsUpdate string
|
||||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
|
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||||
domains := []string{}
|
domains := []string{}
|
||||||
for _, item := range config.Domains {
|
for _, item := range config.Domains {
|
||||||
if item.Disabled {
|
if item.Disabled {
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||||
return newHostManager(s.wgInterface)
|
return newHostManager(s.wgInterface)
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
var errNotImplemented = errors.New("not implemented")
|
var errNotImplemented = errors.New("not implemented")
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
func newSystemdDbusConfigurator(string) (restoreHostManager, error) {
|
||||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,6 +39,7 @@ const (
|
|||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
routingAll bool
|
||||||
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||||
@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct {
|
|||||||
MatchOnly bool
|
MatchOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) {
|
||||||
iface, err := net.InterfaceByName(wgInterface)
|
iface, err := net.InterfaceByName(wgInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get interface: %w", err)
|
return nil, fmt.Errorf("get interface: %w", err)
|
||||||
@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
|||||||
|
|
||||||
return &systemdDbusConfigurator{
|
return &systemdDbusConfigurator{
|
||||||
dbusLinkObject: dbus.ObjectPath(s),
|
dbusLinkObject: dbus.ObjectPath(s),
|
||||||
|
ifaceName: wgInterface,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
parsedIP, err := netip.ParseAddr(config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
return fmt.Errorf("unable to parse ip address, error: %w", err)
|
||||||
@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
|||||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
|
state := &ShutdownState{
|
||||||
// The file content itself is not important for systemd restoration
|
ManagerType: systemdManager,
|
||||||
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
|
WgIface: s.ifaceName,
|
||||||
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
|
||||||
@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
|
||||||
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.flushCaches()
|
return s.flushCaches()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -3,57 +3,25 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
|
type ShutdownState struct {
|
||||||
|
}
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
func (s *ShutdownState) Name() string {
|
||||||
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
|
return "dns_state"
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
}
|
||||||
// no file -> clean shutdown
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
|
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
manager, err := newHostManager()
|
manager, err := newHostManager()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUncleanShutdownIndicator() error {
|
|
||||||
dir := filepath.Dir(fileUncleanShutdownFileLocation)
|
|
||||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
|
||||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
|
|
||||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUncleanShutdownIndicator() error {
|
|
||||||
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
14
client/internal/dns/unclean_shutdown_mobile.go
Normal file
14
client/internal/dns/unclean_shutdown_mobile.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
@ -3,66 +3,44 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CheckUncleanShutdown(wgIface string) error {
|
type ShutdownState struct {
|
||||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
ManagerType osManagerType
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
DNSAddress netip.Addr
|
||||||
// no file -> clean shutdown
|
WgIface string
|
||||||
return nil
|
}
|
||||||
} else {
|
|
||||||
return fmt.Errorf("state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation)
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation)
|
func (s *ShutdownState) Cleanup() error {
|
||||||
if err != nil {
|
manager, err := newHostManagerFromType(s.WgIface, s.ManagerType)
|
||||||
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
managerFields := strings.Split(string(managerData), ",")
|
|
||||||
if len(managerFields) < 2 {
|
|
||||||
return errors.New("split manager data: insufficient number of fields")
|
|
||||||
}
|
|
||||||
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
|
|
||||||
|
|
||||||
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
|
|
||||||
|
|
||||||
// determine os manager type, so we can invoke the respective restore action
|
|
||||||
osManagerType, err := newOsManagerType(osManagerTypeStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("detect previous host manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := newHostManagerFromType(wgIface, osManagerType)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create previous host manager: %w", err)
|
return fmt.Errorf("create previous host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error {
|
// TODO: move file contents to state manager
|
||||||
|
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error {
|
||||||
|
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
|
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
|
||||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
||||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
return fmt.Errorf("create dir %s: %w", dir, err)
|
||||||
@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType
|
|||||||
return fmt.Errorf("create %s: %w", sourcePath, err)
|
return fmt.Errorf("create %s: %w", sourcePath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress)
|
state := &ShutdownState{
|
||||||
|
ManagerType: fileManager,
|
||||||
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec
|
DNSAddress: dnsAddress,
|
||||||
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUncleanShutdownIndicator() error {
|
|
||||||
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
|
|
||||||
}
|
|
||||||
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
|
|
||||||
}
|
}
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
return fmt.Errorf("update state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,75 +1,26 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type ShutdownState struct {
|
||||||
netbirdProgramDataLocation = "Netbird"
|
Guid string
|
||||||
fileUncleanShutdownFile = "unclean_shutdown_dns.txt"
|
}
|
||||||
)
|
|
||||||
|
|
||||||
func CheckUncleanShutdown(string) error {
|
func (s *ShutdownState) Name() string {
|
||||||
file := getUncleanShutdownFile()
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(file); err != nil {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
manager, err := newHostManagerWithGuid(s.Guid)
|
||||||
// no file -> clean shutdown
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("state: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
|
|
||||||
|
|
||||||
guid, err := os.ReadFile(file)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read %s: %w", file, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := newHostManagerWithGuid(string(guid))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown backup: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUncleanShutdownIndicator(guid string) error {
|
|
||||||
file := getUncleanShutdownFile()
|
|
||||||
|
|
||||||
dir := filepath.Dir(file)
|
|
||||||
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
|
|
||||||
return fmt.Errorf("create dir %s: %w", dir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
|
|
||||||
return fmt.Errorf("create %s: %w", file, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUncleanShutdownIndicator() error {
|
|
||||||
file := getUncleanShutdownFile()
|
|
||||||
|
|
||||||
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
return fmt.Errorf("remove %s: %w", file, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUncleanShutdownFile() string {
|
|
||||||
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
|
|
||||||
}
|
|
||||||
|
@ -23,18 +23,19 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@ -166,6 +167,7 @@ type Engine struct {
|
|||||||
checks []*mgmProto.Checks
|
checks []*mgmProto.Checks
|
||||||
|
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@ -213,7 +215,7 @@ func NewEngineWithProbes(
|
|||||||
probes *ProbeHolder,
|
probes *ProbeHolder,
|
||||||
checks []*mgmProto.Checks,
|
checks []*mgmProto.Checks,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
signal: signalClient,
|
signal: signalClient,
|
||||||
@ -232,6 +234,11 @@ func NewEngineWithProbes(
|
|||||||
probes: probes,
|
probes: probes,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
}
|
}
|
||||||
|
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||||
|
engine.stateManager = statemanager.New(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return engine
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) Stop() error {
|
func (e *Engine) Stop() error {
|
||||||
@ -253,7 +260,7 @@ func (e *Engine) Stop() error {
|
|||||||
e.stopDNSServer()
|
e.stopDNSServer()
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
e.routeManager.Stop()
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
@ -275,6 +282,17 @@ func (e *Engine) Stop() error {
|
|||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := e.stateManager.Stop(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||||
|
}
|
||||||
|
if err := e.stateManager.PersistState(ctx); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,6 +332,8 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.stateManager.Start()
|
||||||
|
|
||||||
initialRoutes, dnsServer, err := e.newDnsServer()
|
initialRoutes, dnsServer, err := e.newDnsServer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.close()
|
e.close()
|
||||||
@ -322,7 +342,7 @@ func (e *Engine) Start() error {
|
|||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
} else {
|
} else {
|
||||||
@ -1219,10 +1239,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
default:
|
default:
|
||||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder)
|
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, dnsServer, nil
|
return nil, dnsServer, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@ -31,14 +32,14 @@ import (
|
|||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
EnableServerRouter(firewall firewall.Manager) error
|
||||||
Stop()
|
Stop(stateManager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultManager is the default instance of a route manager
|
// DefaultManager is the default instance of a route manager
|
||||||
@ -120,12 +121,12 @@ func NewManager(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
if nbnet.CustomRoutingDisabled() {
|
if nbnet.CustomRoutingDisabled() {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,7 +137,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
|
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||||
}
|
}
|
||||||
@ -154,7 +155,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the manager watchers and clean firewall rules
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
func (m *DefaultManager) Stop() {
|
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||||
m.stop()
|
m.stop()
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.cleanUp()
|
||||||
@ -172,7 +173,7 @@ func (m *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !nbnet.CustomRoutingDisabled() {
|
if !nbnet.CustomRoutingDisabled() {
|
||||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Routing cleanup complete")
|
log.Info("Routing cleanup complete")
|
||||||
|
@ -426,10 +426,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
|
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
_, _, err = routeManager.Init(nil)
|
||||||
|
|
||||||
require.NoError(t, err, "should init route manager")
|
require.NoError(t, err, "should init route manager")
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop(nil)
|
||||||
|
|
||||||
if testCase.removeSrvRouter {
|
if testCase.removeSrvRouter {
|
||||||
routeManager.serverRouter = nil
|
routeManager.serverRouter = nil
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
@ -17,10 +18,10 @@ type MockManager struct {
|
|||||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
StopFunc func()
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,8 +66,8 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Stop mock implementation of Stop from Manager interface
|
// Stop mock implementation of Stop from Manager interface
|
||||||
func (m *MockManager) Stop() {
|
func (m *MockManager) Stop(stateManager *statemanager.Manager) {
|
||||||
if m.StopFunc != nil {
|
if m.StopFunc != nil {
|
||||||
m.StopFunc()
|
m.StopFunc(stateManager)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
81
client/internal/routemanager/systemops/state.go
Normal file
81
client/internal/routemanager/systemops/state.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RouteEntry struct {
|
||||||
|
Prefix netip.Prefix `json:"prefix"`
|
||||||
|
Nexthop Nexthop `json:"nexthop"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShutdownState struct {
|
||||||
|
Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"`
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewShutdownState() *ShutdownState {
|
||||||
|
return &ShutdownState{
|
||||||
|
Routes: make(map[netip.Prefix]RouteEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "route_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
sysops := NewSysOps(nil, nil)
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, route := range s.Routes {
|
||||||
|
if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.Routes[prefix] = RouteEntry{
|
||||||
|
Prefix: prefix,
|
||||||
|
Nexthop: nexthop,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.Routes, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON ensures that empty routes are marshaled as null
|
||||||
|
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
if len(s.Routes) == 0 {
|
||||||
|
return json.Marshal(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(s.Routes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||||
|
return json.Unmarshal(data, &s.Routes)
|
||||||
|
}
|
@ -9,14 +9,15 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting() error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,6 +29,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func EnableIPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
|||||||
|
|
||||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||||
|
|
||||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||||
@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
|||||||
// These errors are not critical, but also we should not track and try to remove the routes either.
|
// These errors are not critical, but also we should not track and try to remove the routes either.
|
||||||
return nexthop, refcounter.ErrIgnore
|
return nexthop, refcounter.ErrIgnore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState(stateManager, prefix, nexthop)
|
||||||
|
|
||||||
return nexthop, err
|
return nexthop, err
|
||||||
},
|
},
|
||||||
r.removeFromRouteTable,
|
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
// remove from state even if we have trouble removing it from the route table
|
||||||
|
// it could be already gone
|
||||||
|
r.removeFromState(stateManager, prefix)
|
||||||
|
|
||||||
|
return r.removeFromRouteTable(prefix, nexthop)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
r.refCounter = refCounter
|
r.refCounter = refCounter
|
||||||
@ -63,7 +75,25 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
|||||||
return r.setupHooks(initAddresses)
|
return r.setupHooks(initAddresses)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) cleanupRefCounter() error {
|
func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) {
|
||||||
|
state := getState(stateManager)
|
||||||
|
state.UpdateRoute(prefix, nexthop)
|
||||||
|
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) {
|
||||||
|
state := getState(stateManager)
|
||||||
|
state.RemoveRoute(prefix)
|
||||||
|
|
||||||
|
if err := stateManager.UpdateState(state); err != nil {
|
||||||
|
log.Errorf("Failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||||
if r.refCounter == nil {
|
if r.refCounter == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -76,6 +106,10 @@ func (r *SysOps) cleanupRefCounter() error {
|
|||||||
return fmt.Errorf("flush route manager: %w", err)
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
|
log.Errorf("failed to delete state: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -506,3 +540,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
|||||||
// Return true if the longest matching prefix is from vpnRoutes
|
// Return true if the longest matching prefix is from vpnRoutes
|
||||||
return isVpn, longestPrefix
|
return isVpn, longestPrefix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||||
|
var shutdownState *ShutdownState
|
||||||
|
if state := stateManager.GetState(shutdownState); state != nil {
|
||||||
|
shutdownState = state.(*ShutdownState)
|
||||||
|
} else {
|
||||||
|
shutdownState = NewShutdownState()
|
||||||
|
}
|
||||||
|
|
||||||
|
return shutdownState
|
||||||
|
}
|
||||||
|
@ -77,10 +77,10 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
|
|
||||||
_, _, err = r.SetupRouting(nil)
|
_, _, err = r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting())
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
index, err := net.InterfaceByName(wgInterface.Name())
|
||||||
@ -403,10 +403,10 @@ func setupTestEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil)
|
_, _, err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting())
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
index, err := net.InterfaceByName(wgInterface.Name())
|
||||||
|
@ -9,17 +9,18 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
r.prefixes = make(map[netip.Prefix]struct{})
|
r.prefixes = make(map[netip.Prefix]struct{})
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting() error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
@ -46,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) notify() {
|
||||||
|
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
|
||||||
|
for prefix := range r.prefixes {
|
||||||
|
prefixes = append(prefixes, prefix)
|
||||||
|
}
|
||||||
|
r.notifier.OnNewPrefixes(prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func EnableIPForwarding() error {
|
func EnableIPForwarding() error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
return nil
|
return nil
|
||||||
@ -54,11 +67,3 @@ func EnableIPForwarding() error {
|
|||||||
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
|
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
|
||||||
return false, netip.Prefix{}
|
return false, netip.Prefix{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) notify() {
|
|
||||||
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
|
|
||||||
for prefix := range r.prefixes {
|
|
||||||
prefixes = append(prefixes, prefix)
|
|
||||||
}
|
|
||||||
r.notifier.OnNewPrefixes(prefixes)
|
|
||||||
}
|
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -85,10 +86,10 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||||
if isLegacy() {
|
if isLegacy() {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = addRoutingTableName(); err != nil {
|
if err = addRoutingTableName(); err != nil {
|
||||||
@ -104,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := r.CleanupRouting(); cleanErr != nil {
|
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
|||||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||||
setIsLegacy(true)
|
setIsLegacy(true)
|
||||||
return r.setupRefCounter(initAddresses)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
@ -128,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb
|
|||||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func (r *SysOps) CleanupRouting() error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||||
if isLegacy() {
|
if isLegacy() {
|
||||||
return r.cleanupRefCounter()
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
@ -13,15 +13,16 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
return r.setupRefCounter(initAddresses)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting() error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||||
return r.cleanupRefCounter()
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -130,12 +131,12 @@ const (
|
|||||||
RouteDeleted
|
RouteDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||||
return r.setupRefCounter(initAddresses)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting() error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||||
return r.cleanupRefCounter()
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
298
client/internal/statemanager/manager.go
Normal file
298
client/internal/statemanager/manager.go
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
package statemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State interface defines the methods that all state types must implement
|
||||||
|
type State interface {
|
||||||
|
Name() string
|
||||||
|
Cleanup() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager handles the persistence and management of various states
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cancel context.CancelFunc
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
filePath string
|
||||||
|
// holds the states that are registered with the manager and that are to be persisted
|
||||||
|
states map[string]State
|
||||||
|
// holds the state names that have been updated and need to be persisted with the next save
|
||||||
|
dirty map[string]struct{}
|
||||||
|
// holds the type information for each registered state
|
||||||
|
stateTypes map[string]reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Manager instance
|
||||||
|
func New(filePath string) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
filePath: filePath,
|
||||||
|
states: make(map[string]State),
|
||||||
|
dirty: make(map[string]struct{}),
|
||||||
|
stateTypes: make(map[string]reflect.Type),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the state manager periodic save routine
|
||||||
|
func (m *Manager) Start() {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, m.cancel = context.WithCancel(context.Background())
|
||||||
|
m.done = make(chan struct{})
|
||||||
|
|
||||||
|
go m.periodicStateSave(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop(ctx context.Context) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.cancel != nil {
|
||||||
|
m.cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-m.done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterState registers a state with the manager but doesn't attempt to persist it.
|
||||||
|
// Pass an uninitialized state to register it.
|
||||||
|
func (m *Manager) RegisterState(state State) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
name := state.Name()
|
||||||
|
m.states[name] = nil
|
||||||
|
m.stateTypes[name] = reflect.TypeOf(state).Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetState returns the state for the given type
|
||||||
|
func (m *Manager) GetState(state State) State {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
return m.states[state.Name()]
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateState updates the state in the manager and marks it as dirty for the next save.
|
||||||
|
// The state will be replaced with the new one.
|
||||||
|
func (m *Manager) UpdateState(state State) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.setState(state.Name(), state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteState removes the state from the manager and marks it as dirty for the next save.
|
||||||
|
// Pass an uninitialized state to delete it.
|
||||||
|
func (m *Manager) DeleteState(state State) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.setState(state.Name(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) setState(name string, state State) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := m.states[name]; !exists {
|
||||||
|
return fmt.Errorf("state %s not registered", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.states[name] = state
|
||||||
|
m.dirty[name] = struct{}{}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) periodicStateSave(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
defer close(m.done)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := m.PersistState(ctx); err != nil {
|
||||||
|
log.Errorf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersistState persists the states that have been updated since the last save.
|
||||||
|
func (m *Manager) PersistState(ctx context.Context) error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if len(m.dirty) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
data, err := json.MarshalIndent(m.states, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
done <- fmt.Errorf("marshal states: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:gosec
|
||||||
|
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
||||||
|
done <- fmt.Errorf("write state file: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
||||||
|
|
||||||
|
clear(m.dirty)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadState loads the existing state from the state file
|
||||||
|
func (m *Manager) loadState() error {
|
||||||
|
data, err := os.ReadFile(m.filePath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
log.Debug("state file does not exist")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read state file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawStates map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(data, &rawStates); err != nil {
|
||||||
|
log.Warn("State file appears to be corrupted, attempting to delete it")
|
||||||
|
if err := os.Remove(m.filePath); err != nil {
|
||||||
|
log.Errorf("Failed to delete corrupted state file: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Info("State file deleted")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unmarshal states: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
for name, rawState := range rawStates {
|
||||||
|
stateType, ok := m.stateTypes[name]
|
||||||
|
if !ok {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(rawState) == "null" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
statePtr := reflect.New(stateType).Interface().(State)
|
||||||
|
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
m.states[name] = statePtr
|
||||||
|
log.Debugf("loaded state: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
|
||||||
|
// If the cleanup is successful, the state is marked for deletion.
|
||||||
|
func (m *Manager) PerformCleanup() error {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if err := m.loadState(); err != nil {
|
||||||
|
log.Warnf("Failed to load state during cleanup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for name, state := range m.states {
|
||||||
|
if state == nil {
|
||||||
|
// If no state was found in the state file, we don't mark the state dirty nor return an error
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("client was not shut down properly, cleaning up %s", name)
|
||||||
|
if err := state.Cleanup(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
||||||
|
} else {
|
||||||
|
// mark for deletion on cleanup success
|
||||||
|
m.states[name] = nil
|
||||||
|
m.dirty[name] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
35
client/internal/statemanager/path.go
Normal file
35
client/internal/statemanager/path.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package statemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||||
|
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
||||||
|
func GetDefaultStatePath() string {
|
||||||
|
var path string
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||||
|
case "darwin", "linux":
|
||||||
|
path = "/var/lib/netbird/state.json"
|
||||||
|
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
|
path = "/var/db/netbird/state.json"
|
||||||
|
// ios/android don't need state
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return path
|
||||||
|
}
|
@ -138,12 +138,12 @@ func (c *Client) Stop() {
|
|||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ÏSetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStatusDetails return with the list of the PeerInfos
|
// GetStatusDetails return with the list of the PeerInfos
|
||||||
func (c *Client) GetStatusDetails() *StatusDetails {
|
func (c *Client) GetStatusDetails() *StatusDetails {
|
||||||
|
|
||||||
fullStatus := c.recorder.GetFullStatus()
|
fullStatus := c.recorder.GetFullStatus()
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
@ -20,7 +21,11 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@ -39,6 +44,8 @@ const (
|
|||||||
defaultMaxRetryInterval = 60 * time.Minute
|
defaultMaxRetryInterval = 60 * time.Minute
|
||||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||||
defaultRetryMultiplier = 1.7
|
defaultRetryMultiplier = 1.7
|
||||||
|
|
||||||
|
errRestoreResidualState = "failed to restore residual state: %v"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server for service control.
|
// Server for service control.
|
||||||
@ -95,6 +102,10 @@ func (s *Server) Start() error {
|
|||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
|
||||||
|
if err := restoreResidualState(s.rootCtx); err != nil {
|
||||||
|
log.Warnf(errRestoreResidualState, err)
|
||||||
|
}
|
||||||
|
|
||||||
// if current state contains any error, return it
|
// if current state contains any error, return it
|
||||||
// in all other cases we can continue execution only if status is idle and up command was
|
// in all other cases we can continue execution only if status is idle and up command was
|
||||||
// not in the progress or already successfully established connection.
|
// not in the progress or already successfully established connection.
|
||||||
@ -292,6 +303,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
s.actCancel = cancel
|
s.actCancel = cancel
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := restoreResidualState(ctx); err != nil {
|
||||||
|
log.Warnf(errRestoreResidualState, err)
|
||||||
|
}
|
||||||
|
|
||||||
state := internal.CtxGetState(ctx)
|
state := internal.CtxGetState(ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
status, err := state.Status()
|
status, err := state.Status()
|
||||||
@ -549,6 +564,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := restoreResidualState(callerCtx); err != nil {
|
||||||
|
log.Warnf(errRestoreResidualState, err)
|
||||||
|
}
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
state := internal.CtxGetState(s.rootCtx)
|
||||||
|
|
||||||
// if current state contains any error, return it
|
// if current state contains any error, return it
|
||||||
@ -829,3 +848,31 @@ func sendTerminalNotification() error {
|
|||||||
|
|
||||||
return wallCmd.Wait()
|
return wallCmd.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required.
|
||||||
|
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||||
|
func restoreResidualState(ctx context.Context) error {
|
||||||
|
path := statemanager.GetDefaultStatePath()
|
||||||
|
if path == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := statemanager.New(path)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// register the states we are interested in restoring
|
||||||
|
// this will also allow each subsystem to record its own state
|
||||||
|
mgr.RegisterState(&dns.ShutdownState{})
|
||||||
|
mgr.RegisterState(&systemops.ShutdownState{})
|
||||||
|
|
||||||
|
if err := mgr.PerformCleanup(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.PersistState(ctx); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user