Restore dns on unclean shutdown (#1494)

This commit is contained in:
Viktor Liu 2024-01-30 09:58:56 +01:00 committed by GitHub
parent 9c56f74235
commit 846d486366
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 660 additions and 143 deletions

View File

@ -11,11 +11,12 @@ import (
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/spf13/cobra"
"google.golang.org/grpc"
) )
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
@ -109,7 +110,6 @@ var runCmd = &cobra.Command{
if err != nil { if err != nil {
return err return err
} }
cmd.Printf("Netbird service is running")
return nil return nil
}, },
} }

View File

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -93,6 +94,12 @@ func runClient(
) error { ) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion()) log.Infof("starting NetBird client version %s", version.NetbirdVersion())
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@ -244,7 +251,7 @@ func runClient(
log.Info("stopped NetBird client") log.Info("stopped NetBird client")
if _, err := state.Status(); err == ErrResetConnection { if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
return err return err
} }

View File

@ -4,9 +4,11 @@ package dns
import ( import (
"context" "context"
"fmt"
"time"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"time"
) )
const dbusDefaultFlag = 0 const dbusDefaultFlag = 0
@ -14,6 +16,7 @@ const dbusDefaultFlag = 0
func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool { func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(dest, path) obj, closeConn, err := getDbusObject(dest, path)
if err != nil { if err != nil {
log.Tracef("error getting dbus object: %s", err)
return false return false
} }
defer closeConn() defer closeConn()
@ -21,14 +24,18 @@ func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool {
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel() defer cancel()
err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store() if err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store(); err != nil {
return err == nil log.Tracef("error calling dbus: %s", err)
return false
}
return true
} }
func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) { func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) {
conn, err := dbus.SystemBus() conn, err := dbus.SystemBus()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("get dbus: %w", err)
} }
obj := conn.Object(dest, path) obj := conn.Object(dest, path)

View File

@ -5,6 +5,7 @@ package dns
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net/netip"
"os" "os"
"strings" "strings"
@ -49,7 +50,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
if backupFileExist { if backupFileExist {
err = f.restore() err = f.restore()
if err != nil { if err != nil {
return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err) return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %w", err)
} }
} }
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
@ -58,7 +59,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
if !backupFileExist { if !backupFileExist {
err = f.backup() err = f.backup()
if err != nil { if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file") return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
} }
} }
@ -67,7 +68,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
resolvConf, err := parseBackupResolvConf() resolvConf, err := parseBackupResolvConf()
if err != nil { if err != nil {
log.Error(err) log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err)
} }
f.repair.stopWatchFileChanges() f.repair.stopWatchFileChanges()
@ -96,10 +97,16 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
if restoreErr != nil { if restoreErr != nil {
log.Errorf("attempt to restore default file failed with error: %s", err) log.Errorf("attempt to restore default file failed with error: %s", err)
} }
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err) return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
}
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
} }
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
return nil return nil
} }
@ -111,14 +118,14 @@ func (f *fileConfigurator) restoreHostDNS() error {
func (f *fileConfigurator) backup() error { func (f *fileConfigurator) backup() error {
stats, err := os.Stat(defaultResolvConfPath) stats, err := os.Stat(defaultResolvConfPath)
if err != nil { if err != nil {
return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err) return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, err)
} }
f.originalPerms = stats.Mode() f.originalPerms = stats.Mode()
err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation) err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation)
if err != nil { if err != nil {
return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err) return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, err)
} }
return nil return nil
} }
@ -126,12 +133,58 @@ func (f *fileConfigurator) backup() error {
func (f *fileConfigurator) restore() error { func (f *fileConfigurator) restore() error {
err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil { if err != nil {
return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
} }
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
resolvConf, err := parseDefaultResolvConf()
if err != nil {
return fmt.Errorf("parse current resolv.conf: %w", err)
}
// no current nameservers set -> restore
if len(resolvConf.nameServers) == 0 {
return restoreResolvConfFile()
}
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[1], err)
return restoreResolvConfFile()
}
// current address is still netbird's non-available dns address -> restore
// comparing parsed addresses only, to remove ambiguity
if currentDNSAddress.String() == storedDNSAddress.String() {
return restoreResolvConfFile()
}
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
return nil
}
func restoreResolvConfFile() error {
log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)
if err := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath); err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
}
return nil
}
// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list // generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string { func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
ns := make([]string, 1, len(cfg.nameServers)+1) ns := make([]string, 1, len(cfg.nameServers)+1)
@ -231,17 +284,17 @@ func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string
func copyFile(src, dest string) error { func copyFile(src, dest string) error {
stats, err := os.Stat(src) stats, err := os.Stat(src)
if err != nil { if err != nil {
return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err) return fmt.Errorf("checking stats for %s file when copying it. Error: %s", src, err)
} }
bytesRead, err := os.ReadFile(src) bytesRead, err := os.ReadFile(src)
if err != nil { if err != nil {
return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err) return fmt.Errorf("reading the file %s file for copy. Error: %s", src, err)
} }
err = os.WriteFile(dest, bytesRead, stats.Mode()) err = os.WriteFile(dest, bytesRead, stats.Mode())
if err != nil { if err != nil {
return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err) return fmt.Errorf("writing the destination file %s for copy. Error: %s", dest, err)
} }
return nil return nil
} }

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net/netip"
"strings" "strings"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -11,6 +12,7 @@ type hostManager interface {
applyDNSConfig(config HostDNSConfig) error applyDNSConfig(config HostDNSConfig) error
restoreHostDNS() error restoreHostDNS() error
supportCustomPort() bool supportCustomPort() bool
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type HostDNSConfig struct { type HostDNSConfig struct {
@ -27,9 +29,10 @@ type DomainConfig struct {
} }
type mockHostConfigurator struct { type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig) error applyDNSConfigFunc func(config HostDNSConfig) error
restoreHostDNSFunc func() error restoreHostDNSFunc func() error
supportCustomPortFunc func() bool supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
} }
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
@ -53,11 +56,19 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false return false
} }
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
if m.restoreUncleanShutdownDNSFunc != nil {
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
}
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
}
func newNoopHostMocker() hostManager { func newNoopHostMocker() hostManager {
return &mockHostConfigurator{ return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil }, applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
restoreHostDNSFunc: func() error { return nil }, restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true }, supportCustomPortFunc: func() bool { return true },
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },
} }
} }

View File

@ -1,9 +1,11 @@
package dns package dns
import "net/netip"
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager(wgInterface WGIface) (hostManager, error) { func newHostManager() (hostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }
@ -18,3 +20,7 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) supportCustomPort() bool {
return false return false
} }
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@ -6,6 +6,8 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"fmt" "fmt"
"io"
"net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
@ -34,7 +36,7 @@ type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
} }
func newHostManager(_ WGIface) (hostManager, error) { func newHostManager() (hostManager, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil
@ -50,17 +52,22 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
if config.RouteAll { if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort) err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil { if err != nil {
return err return fmt.Errorf("add dns setup for all: %w", err)
} }
} else if s.primaryServiceID != "" { } else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil { if err != nil {
return err return fmt.Errorf("remote key from system config: %w", err)
} }
s.primaryServiceID = "" s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort) log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
} }
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string
@ -85,7 +92,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
err = s.removeKeyFromSystemConfig(matchKey) err = s.removeKeyFromSystemConfig(matchKey)
} }
if err != nil { if err != nil {
return err return fmt.Errorf("add match domains: %w", err)
} }
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
@ -96,7 +103,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
err = s.removeKeyFromSystemConfig(searchKey) err = s.removeKeyFromSystemConfig(searchKey)
} }
if err != nil { if err != nil {
return err return fmt.Errorf("add search domains: %w", err)
} }
return nil return nil
@ -119,7 +126,11 @@ func (s *systemConfigurator) restoreHostDNS() error {
_, err := runSystemConfigCommand(wrapCommand(lines)) _, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil { if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err) log.Errorf("got an error while cleaning the system configuration: %s", err)
return err return fmt.Errorf("clean system: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
} }
return nil return nil
@ -129,7 +140,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key) line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line)) _, err := runSystemConfigCommand(wrapCommand(line))
if err != nil { if err != nil {
return err return fmt.Errorf("remove key: %w", err)
} }
delete(s.createdKeys, key) delete(s.createdKeys, key)
@ -140,7 +151,7 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {
return err return fmt.Errorf("add dns state: %w", err)
} }
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
@ -153,7 +164,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error { func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error {
err := s.addDNSState(key, domains, dnsServer, port, false) err := s.addDNSState(key, domains, dnsServer, port, false)
if err != nil { if err != nil {
return err return fmt.Errorf("add dns state: %w", err)
} }
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
@ -178,33 +189,37 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
_, err := runSystemConfigCommand(stdinCommands) _, err := runSystemConfigCommand(stdinCommands)
if err != nil { if err != nil {
return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err) return fmt.Errorf("applying state for domains %s, error: %w", domains, err)
} }
return nil return nil
} }
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver := s.getPrimaryService() primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if primaryServiceKey == "" { if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key") return fmt.Errorf("couldn't find the primary service key: %w", err)
} }
err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil { if err != nil {
return err return fmt.Errorf("add dns setup: %w", err)
} }
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port) log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey s.primaryServiceID = primaryServiceKey
return nil return nil
} }
func (s *systemConfigurator) getPrimaryService() (string, string) { func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "") line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line) stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands) b, err := runSystemConfigCommand(stdinCommands)
if err != nil { if err != nil {
log.Error("got error while sending the command: ", err) return "", "", fmt.Errorf("sending the command: %w", err)
return "", ""
} }
scanner := bufio.NewScanner(bytes.NewReader(b)) scanner := bufio.NewScanner(bytes.NewReader(b))
primaryService := "" primaryService := ""
router := "" router := ""
@ -217,7 +232,11 @@ func (s *systemConfigurator) getPrimaryService() (string, string) {
router = strings.TrimSpace(strings.Split(text, ":")[1]) router = strings.TrimSpace(strings.Split(text, ":")[1])
} }
} }
return primaryService, router if err := scanner.Err(); err != nil && err != io.EOF {
return primaryService, router, fmt.Errorf("scan: %w", err)
}
return primaryService, router, nil
} }
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error { func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
@ -228,7 +247,14 @@ func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, e
stdinCommands := wrapCommand(addDomainCommand) stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands) _, err := runSystemConfigCommand(stdinCommands)
if err != nil { if err != nil {
return fmt.Errorf("got error while applying dns setup, error: %s", err) return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)
} }
return nil return nil
} }
@ -266,7 +292,7 @@ func runSystemConfigCommand(command string) ([]byte, error) {
cmd.Stdin = strings.NewReader(command) cmd.Stdin = strings.NewReader(command)
out, err := cmd.Output() out, err := cmd.Output()
if err != nil { if err != nil {
return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err) return nil, fmt.Errorf("running system configuration command: \"%s\", error: %w", command, err)
} }
return out, nil return out, nil
} }

View File

@ -2,6 +2,8 @@ package dns
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -20,7 +22,7 @@ func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
jsonData, err := json.Marshal(config) jsonData, err := json.Marshal(config)
if err != nil { if err != nil {
return err return fmt.Errorf("marshal: %w", err)
} }
jsonString := string(jsonData) jsonString := string(jsonData)
log.Debugf("Applying DNS settings: %s", jsonString) log.Debugf("Applying DNS settings: %s", jsonString)
@ -35,3 +37,7 @@ func (a iosHostManager) restoreHostDNS() error {
func (a iosHostManager) supportCustomPort() bool { func (a iosHostManager) supportCustomPort() bool {
return false return false
} }
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@ -4,7 +4,9 @@ package dns
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io"
"os" "os"
"strings" "strings"
@ -19,8 +21,27 @@ const (
resolvConfManager resolvConfManager
) )
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
type osManagerType int type osManagerType int
func newOsManagerType(osManager string) (osManagerType, error) {
switch osManager {
case "netbird":
return fileManager, nil
case "file":
return netbirdManager, nil
case "networkManager":
return networkManager, nil
case "systemd":
return systemdManager, nil
case "resolvconf":
return resolvConfManager, nil
default:
return 0, ErrUnknownOsManagerType
}
}
func (t osManagerType) String() string { func (t osManagerType) String() string {
switch t { switch t {
case netbirdManager: case netbirdManager:
@ -38,13 +59,17 @@ func (t osManagerType) String() string {
} }
} }
func newHostManager(wgInterface WGIface) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("discovered mode is: %s", osManager) log.Debugf("discovered mode is: %s", osManager)
return newHostManagerFromType(wgInterface, osManager)
}
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
switch osManager { switch osManager {
case networkManager: case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface) return newNetworkManagerDbusConfigurator(wgInterface)
@ -58,12 +83,15 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
} }
func getOSDNSManagerType() (osManagerType, error) { func getOSDNSManagerType() (osManagerType, error) {
file, err := os.Open(defaultResolvConfPath) file, err := os.Open(defaultResolvConfPath)
if err != nil { if err != nil {
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err) return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
} }
defer file.Close() defer func() {
if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
}
}()
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
@ -101,6 +129,10 @@ func getOSDNSManagerType() (osManagerType, error) {
return resolvConfManager, nil return resolvConfManager, nil
} }
} }
if err := scanner.Err(); err != nil && err != io.EOF {
return 0, fmt.Errorf("scan: %w", err)
}
return fileManager, nil return fileManager, nil
} }

View File

@ -2,6 +2,8 @@ package dns
import ( import (
"fmt" "fmt"
"io"
"net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -9,7 +11,7 @@ import (
) )
const ( const (
dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match" dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
dnsPolicyConfigVersionKey = "Version" dnsPolicyConfigVersionKey = "Version"
dnsPolicyConfigVersionValue = 2 dnsPolicyConfigVersionValue = 2
dnsPolicyConfigNameKey = "Name" dnsPolicyConfigNameKey = "Name"
@ -19,7 +21,7 @@ const (
) )
const ( const (
interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces" interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
) )
@ -34,12 +36,16 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newHostManagerWithGuid(guid)
}
func newHostManagerWithGuid(guid string) (hostManager, error) {
return &registryConfigurator{ return &registryConfigurator{
guid: guid, guid: guid,
}, nil }, nil
} }
func (s *registryConfigurator) supportCustomPort() bool { func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
@ -48,17 +54,22 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
if config.RouteAll { if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP) err = r.addDNSSetupForAll(config.ServerIP)
if err != nil { if err != nil {
return err return fmt.Errorf("add dns setup: %w", err)
} }
} else if r.routingAll { } else if r.routingAll {
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey) err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
if err != nil { if err != nil {
return err return fmt.Errorf("delete interface registry key property: %w", err)
} }
r.routingAll = false r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(r.guid); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string
@ -80,12 +91,12 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
} }
if err != nil { if err != nil {
return err return fmt.Errorf("add dns match policy: %w", err)
} }
err = r.updateSearchDomains(searchDomains) err = r.updateSearchDomains(searchDomains)
if err != nil { if err != nil {
return err return fmt.Errorf("update search domains: %w", err)
} }
return nil return nil
@ -94,7 +105,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (r *registryConfigurator) addDNSSetupForAll(ip string) error { func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip) err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
if err != nil { if err != nil {
return fmt.Errorf("adding dns setup for all failed with error: %s", err) return fmt.Errorf("adding dns setup for all failed with error: %w", err)
} }
r.routingAll = true r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
@ -106,33 +117,33 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
if err == nil { if err == nil {
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath) err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
if err != nil { if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err) return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
} }
} }
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE) regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
if err != nil { if err != nil {
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err) return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
} }
err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue) err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
if err != nil { if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err) return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
} }
err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains) err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
if err != nil { if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err) return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
} }
err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip) err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
if err != nil { if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err) return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
} }
err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue) err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
if err != nil { if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err) return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
} }
log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains) log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)
@ -141,18 +152,25 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
} }
func (r *registryConfigurator) restoreHostDNS() error { func (r *registryConfigurator) restoreHostDNS() error {
err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
if err != nil { log.Errorf("remove registry key from dns policy config: %s", err)
log.Error(err)
} }
return r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey) if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
return fmt.Errorf("remove interface registry key: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil
} }
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")) err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
if err != nil { if err != nil {
return fmt.Errorf("adding search domain failed with error: %s", err) return fmt.Errorf("adding search domain failed with error: %w", err)
} }
log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains) log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)
@ -163,13 +181,13 @@ func (r *registryConfigurator) updateSearchDomains(domains []string) error {
func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error { func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error {
regKey, err := r.getInterfaceRegistryKey() regKey, err := r.getInterfaceRegistryKey()
if err != nil { if err != nil {
return err return fmt.Errorf("get interface registry key: %w", err)
} }
defer regKey.Close() defer closer(regKey)
err = regKey.SetStringValue(key, value) err = regKey.SetStringValue(key, value)
if err != nil { if err != nil {
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err) return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
} }
return nil return nil
@ -178,13 +196,13 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error { func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error {
regKey, err := r.getInterfaceRegistryKey() regKey, err := r.getInterfaceRegistryKey()
if err != nil { if err != nil {
return err return fmt.Errorf("get interface registry key: %w", err)
} }
defer regKey.Close() defer closer(regKey)
err = regKey.DeleteValue(propertyKey) err = regKey.DeleteValue(propertyKey)
if err != nil { if err != nil {
return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err) return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
} }
return nil return nil
@ -197,20 +215,33 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE) regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
if err != nil { if err != nil {
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err) return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
} }
return regKey, nil return regKey, nil
} }
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err)
}
return nil
}
func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error { func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE) k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
if err == nil { if err == nil {
k.Close() defer closer(k)
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath) err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
if err != nil { if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err) return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
} }
} }
return nil return nil
} }
func closer(closer io.Closer) {
if err := closer.Close(); err != nil {
log.Errorf("failed to close: %s", err)
}
}

View File

@ -52,7 +52,7 @@ func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error { func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
fullRecord, err := dns.NewRR(record.String()) fullRecord, err := dns.NewRR(record.String())
if err != nil { if err != nil {
return err return fmt.Errorf("register record: %w", err)
} }
fullRecord.Header().Rdlength = record.Len() fullRecord.Header().Rdlength = record.Len()

View File

@ -5,8 +5,10 @@ package dns
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"time" "time"
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
@ -41,9 +43,13 @@ const (
networkManagerDbusPrimaryDNSPriority int32 = -500 networkManagerDbusPrimaryDNSPriority int32 = -500
networkManagerDbusWithMatchDomainPriority int32 = 0 networkManagerDbusWithMatchDomainPriority int32 = 0
networkManagerDbusSearchDomainOnlyPriority int32 = 50 networkManagerDbusSearchDomainOnlyPriority int32 = 50
supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28"
) )
var supportedNetworkManagerVersionConstraints = []string{
">= 1.16, < 1.27",
">= 1.44, < 1.45",
}
type networkManagerDbusConfigurator struct { type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
@ -71,19 +77,19 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) { func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("get nm dbus: %w", err)
} }
defer closeConn() defer closeConn()
var s string var s string
err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.Name()).Store(&s) err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface).Store(&s)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("call: %w", err)
} }
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.Name()) log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
return &networkManagerDbusConfigurator{ return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
@ -97,14 +103,14 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings() connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil { if err != nil {
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err) return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
} }
connSettings.cleanDeprecatedSettings() connSettings.cleanDeprecatedSettings()
dnsIP, err := netip.ParseAddr(config.ServerIP) dnsIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %s", err) return fmt.Errorf("unable to parse ip address, error: %w", err)
} }
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice()) convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
@ -145,23 +151,37 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
// The file content itself is not important for network-manager restoration
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
}
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
err = n.reApplyConnectionSettings(connSettings, configVersion) err = n.reApplyConnectionSettings(connSettings, configVersion)
if err != nil { if err != nil {
return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err) return fmt.Errorf("reapplying the connection with new settings, error: %w", err)
} }
return nil return nil
} }
func (n *networkManagerDbusConfigurator) restoreHostDNS() error { func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
// once the interface is gone network manager cleans all config associated with it // once the interface is gone network manager cleans all config associated with it
return n.deleteConnectionSettings() if err := n.deleteConnectionSettings(); err != nil {
return fmt.Errorf("delete connection settings: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return nil
} }
func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) { func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) return nil, 0, fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
} }
defer closeConn() defer closeConn()
@ -176,7 +196,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag, err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag,
networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion) networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err) return nil, 0, fmt.Errorf("calling GetAppliedConnection method with context, err: %w", err)
} }
return connSettings, configVersion, nil return connSettings, configVersion, nil
@ -185,7 +205,7 @@ func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (network
func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error { func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil { if err != nil {
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
} }
defer closeConn() defer closeConn()
@ -195,7 +215,7 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag, err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag,
connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store() connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store()
if err != nil { if err != nil {
return fmt.Errorf("got error while calling ReApply method with context, err: %s", err) return fmt.Errorf("calling ReApply method with context, err: %w", err)
} }
return nil return nil
@ -204,21 +224,34 @@ func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings
func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error { func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject)
if err != nil { if err != nil {
return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) return fmt.Errorf("attempting to retrieve the applied connection settings, err: %w", err)
} }
defer closeConn() defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel() defer cancel()
// this call is required to remove the device for DNS cleanup, even if it fails
err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store() err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store()
if err != nil { if err != nil {
return fmt.Errorf("got error while calling delete method with context, err: %s", err) var dbusErr dbus.Error
if errors.As(err, &dbusErr) && dbusErr.Name == dbus.ErrMsgUnknownMethod.Name {
// interface is gone already
return nil
}
return fmt.Errorf("calling delete method with context, err: %s", err)
} }
return nil return nil
} }
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := n.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via network-manager: %w", err)
}
return nil
}
func isNetworkManagerSupported() bool { func isNetworkManagerSupported() bool {
return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode() return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode()
} }
@ -250,13 +283,13 @@ func isNetworkManagerSupportedMode() bool {
func getNetworkManagerDNSProperty(property string, store any) error { func getNetworkManagerDNSProperty(property string, store any) error {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode)
if err != nil { if err != nil {
return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err) return fmt.Errorf("attempting to retrieve the network manager dns manager object, error: %w", err)
} }
defer closeConn() defer closeConn()
v, e := obj.GetProperty(property) v, e := obj.GetProperty(property)
if e != nil { if e != nil {
return fmt.Errorf("got an error getting property %s: %v", property, e) return fmt.Errorf("getting property %s: %w", property, e)
} }
return v.Store(store) return v.Store(store)
@ -278,15 +311,26 @@ func isNetworkManagerSupportedVersion() bool {
} }
versionValue, err := parseVersion(value.Value().(string)) versionValue, err := parseVersion(value.Value().(string))
if err != nil { if err != nil {
log.Errorf("nm: parse version: %s", err)
return false return false
} }
constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint) var supported bool
if err != nil { for _, constraint := range supportedNetworkManagerVersionConstraints {
return false constr, err := version.NewConstraint(constraint)
if err != nil {
log.Errorf("nm: create constraint: %s", err)
return false
}
if met := constr.Check(versionValue); met {
supported = true
break
}
} }
return constraints.Check(versionValue) log.Debugf("network manager constraints [%s] met: %t", strings.Join(supportedNetworkManagerVersionConstraints, " | "), supported)
return supported
} }
func parseVersion(inputVersion string) (*version.Version, error) { func parseVersion(inputVersion string) (*version.Version, error) {

View File

@ -5,6 +5,7 @@ package dns
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -21,14 +22,14 @@ type resolvconf struct {
} }
// supported "openresolv" only // supported "openresolv" only
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) { func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
resolvConfEntries, err := parseDefaultResolvConf() resolvConfEntries, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Error(err) log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
} }
return &resolvconf{ return &resolvconf{
ifaceName: wgInterface.Name(), ifaceName: wgInterface,
originalSearchDomains: resolvConfEntries.searchDomains, originalSearchDomains: resolvConfEntries.searchDomains,
originalNameServers: resolvConfEntries.nameServers, originalNameServers: resolvConfEntries.nameServers,
othersConfigs: resolvConfEntries.others, othersConfigs: resolvConfEntries.others,
@ -44,7 +45,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
if !config.RouteAll { if !config.RouteAll {
err = r.restoreHostDNS() err = r.restoreHostDNS()
if err != nil { if err != nil {
log.Error(err) log.Errorf("restore host dns: %s", err)
} }
return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured") return fmt.Errorf("unable to configure DNS for this peer using resolvconf manager without a nameserver group with all domains configured")
} }
@ -57,9 +58,14 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
append([]string{config.ServerIP}, r.originalNameServers...), append([]string{config.ServerIP}, r.originalNameServers...),
r.othersConfigs) r.othersConfigs)
// create a backup for unclean shutdown detection before the resolv.conf is changed
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
}
err = r.applyConfig(buf) err = r.applyConfig(buf)
if err != nil { if err != nil {
return err return fmt.Errorf("apply config: %w", err)
} }
log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList) log.Infof("added %d search domains. Search list: %s", len(searchDomainList), searchDomainList)
@ -67,20 +73,34 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
} }
func (r *resolvconf) restoreHostDNS() error { func (r *resolvconf) restoreHostDNS() error {
// openresolv only, debian resolvconf doesn't support "-f"
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("got an error while removing resolvconf configuration for %s interface, error: %s", r.ifaceName, err) return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return nil return nil
} }
func (r *resolvconf) applyConfig(content bytes.Buffer) error { func (r *resolvconf) applyConfig(content bytes.Buffer) error {
// openresolv only, debian resolvconf doesn't support "-x"
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName) cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
cmd.Stdin = &content cmd.Stdin = &content
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("got an error while applying resolvconf configuration for %s interface, error: %s", r.ifaceName, err) return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
}
return nil
}
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
} }
return nil return nil
} }

View File

@ -31,10 +31,13 @@ func (r *responseWriter) RemoteAddr() net.Addr {
func (r *responseWriter) WriteMsg(msg *dns.Msg) error { func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
buff, err := msg.Pack() buff, err := msg.Pack()
if err != nil { if err != nil {
return err return fmt.Errorf("pack: %w", err)
} }
_, err = r.Write(buff)
return err if _, err := r.Write(buff); err != nil {
return fmt.Errorf("write: %w", err)
}
return nil
} }
// Write writes a raw buffer back to the client. // Write writes a raw buffer back to the client.

View File

@ -142,12 +142,15 @@ func (s *DefaultServer) Initialize() (err error) {
if s.permanent { if s.permanent {
err = s.service.Listen() err = s.service.Listen()
if err != nil { if err != nil {
return err return fmt.Errorf("service listen: %w", err)
} }
} }
s.hostManager, err = s.initialize() s.hostManager, err = s.initialize()
return err if err != nil {
return fmt.Errorf("initialize: %w", err)
}
return nil
} }
// DnsIP returns the DNS resolver server IP address // DnsIP returns the DNS resolver server IP address
@ -225,7 +228,7 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
} }
if err := s.applyConfiguration(update); err != nil { if err := s.applyConfiguration(update); err != nil {
return err return fmt.Errorf("apply configuration: %w", err)
} }
s.updateSerial = serial s.updateSerial = serial

View File

@ -1,5 +1,5 @@
package dns package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) { func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface) return newHostManager()
} }

View File

@ -3,5 +3,5 @@
package dns package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) { func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface) return newHostManager()
} }

View File

@ -3,5 +3,5 @@
package dns package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) { func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface) return newHostManager(s.wgInterface.Name())
} }

View File

@ -63,7 +63,7 @@ func (s *serviceViaListener) Listen() error {
s.listenIP, s.listenPort, err = s.evalListenAddress() s.listenIP, s.listenPort, err = s.evalListenAddress()
if err != nil { if err != nil {
log.Errorf("failed to eval runtime address: %s", err) log.Errorf("failed to eval runtime address: %s", err)
return err return fmt.Errorf("eval listen address: %w", err)
} }
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)

View File

@ -44,7 +44,7 @@ func (s *serviceViaMemory) Listen() error {
var err error var err error
s.udpFilterHookID, err = s.filterDNSTraffic() s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil { if err != nil {
return err return fmt.Errorf("filter dns traffice: %w", err)
} }
s.listenerIsRunning = true s.listenerIsRunning = true

View File

@ -4,6 +4,7 @@ package dns
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -30,6 +31,8 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusResolvConfModeForeign = "foreign" systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
) )
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
@ -52,22 +55,22 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.Name()) iface, err := net.InterfaceByName(wgInterface)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("get interface: %w", err)
} }
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("get dbus resolved dest: %w", err)
} }
defer closeConn() defer closeConn()
var s string var s string
err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s) err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("get dbus link method: %w", err)
} }
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index) log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
parsedIP, err := netip.ParseAddr(config.ServerIP) parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %s", err) return fmt.Errorf("unable to parse ip address, error: %w", err)
} }
ipAs4 := parsedIP.As4() ipAs4 := parsedIP.As4()
defaultLinkInput := systemdDbusDNSInput{ defaultLinkInput := systemdDbusDNSInput{
@ -93,7 +96,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
} }
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
if err != nil { if err != nil {
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.ServerIP, config.ServerPort, err) return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
} }
var ( var (
@ -121,7 +124,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
if err != nil { if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %s", err) return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: nbdns.RootZone, Domain: nbdns.RootZone,
@ -132,6 +135,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} }
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
// The file content itself is not important for systemd restoration
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
}
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
err = s.setDomainsForInterface(domainsInput) err = s.setDomainsForInterface(domainsInput)
if err != nil { if err != nil {
@ -143,7 +152,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error { func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error {
err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput) err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput)
if err != nil { if err != nil {
return fmt.Errorf("setting domains configuration failed with error: %s", err) return fmt.Errorf("setting domains configuration failed with error: %w", err)
} }
return s.flushCaches() return s.flushCaches()
} }
@ -153,17 +162,29 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) { if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) {
return nil return nil
} }
// this call is required for DNS cleanup, even if it fails
err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil) err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil)
if err != nil { if err != nil {
return fmt.Errorf("unable to revert link configuration, got error: %s", err) var dbusErr dbus.Error
if errors.As(err, &dbusErr) && dbusErr.Name == dbusErrorUnknownObject {
// interface is gone already
return nil
}
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return s.flushCaches() return s.flushCaches()
} }
func (s *systemdDbusConfigurator) flushCaches() error { func (s *systemdDbusConfigurator) flushCaches() error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil { if err != nil {
return fmt.Errorf("got error while attempting to retrieve the object %s, err: %s", systemdDbusObjectNode, err) return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
} }
defer closeConn() defer closeConn()
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
@ -171,7 +192,7 @@ func (s *systemdDbusConfigurator) flushCaches() error {
err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store() err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store()
if err != nil { if err != nil {
return fmt.Errorf("got error while calling the FlushCaches method with context, err: %s", err) return fmt.Errorf("calling the FlushCaches method with context, err: %w", err)
} }
return nil return nil
@ -180,7 +201,7 @@ func (s *systemdDbusConfigurator) flushCaches() error {
func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error { func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject) obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject)
if err != nil { if err != nil {
return fmt.Errorf("got error while attempting to retrieve the object, err: %s", err) return fmt.Errorf("attempting to retrieve the object, err: %w", err)
} }
defer closeConn() defer closeConn()
@ -194,22 +215,29 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
} }
if err != nil { if err != nil {
return fmt.Errorf("got error while calling command with context, err: %s", err) return fmt.Errorf("calling command with context, err: %w", err)
} }
return nil return nil
} }
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via systemd: %w", err)
}
return nil
}
func getSystemdDbusProperty(property string, store any) error { func getSystemdDbusProperty(property string, store any) error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil { if err != nil {
return fmt.Errorf("got error while attempting to retrieve the systemd dns manager object, error: %s", err) return fmt.Errorf("attempting to retrieve the systemd dns manager object, error: %w", err)
} }
defer closeConn() defer closeConn()
v, e := obj.GetProperty(property) v, e := obj.GetProperty(property)
if e != nil { if e != nil {
return fmt.Errorf("got an error getting property %s: %v", property, e) return fmt.Errorf("getting property %s: %w", property, e)
} }
return v.Store(store) return v.Store(store)

View File

@ -0,0 +1,5 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@ -0,0 +1,59 @@
//go:build !ios
package dns
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
)
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
func CheckUncleanShutdown(string) error {
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
manager, err := newHostManager()
if err != nil {
return fmt.Errorf("create host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err)
}
return nil
}
func createUncleanShutdownIndicator() error {
dir := filepath.Dir(fileUncleanShutdownFileLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}

View File

@ -0,0 +1,5 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@ -0,0 +1,96 @@
//go:build !android
package dns
import (
"errors"
"fmt"
"io/fs"
"net/netip"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation)
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation)
if err != nil {
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
managerFields := strings.Split(string(managerData), ",")
if len(managerFields) < 2 {
return errors.New("split manager data: insufficient number of fields")
}
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
}
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
// determine os manager type, so we can invoke the respective restore action
osManagerType, err := newOsManagerType(osManagerTypeStr)
if err != nil {
return fmt.Errorf("detect previous host manager: %w", err)
}
manager, err := newHostManagerFromType(wgIface, osManagerType)
if err != nil {
return fmt.Errorf("create previous host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err)
}
return nil
}
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error {
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := copyFile(sourcePath, fileUncleanShutdownResolvConfLocation); err != nil {
return fmt.Errorf("create %s: %w", sourcePath, err)
}
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress)
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
}
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
return nil
}

View File

@ -0,0 +1,75 @@
package dns
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
)
const (
netbirdProgramDataLocation = "Netbird"
fileUncleanShutdownFile = "unclean_shutdown_dns.txt"
)
func CheckUncleanShutdown(string) error {
file := getUncleanShutdownFile()
if _, err := os.Stat(file); err != nil {
if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
guid, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("read %s: %w", file, err)
}
manager, err := newHostManagerWithGuid(string(guid))
if err != nil {
return fmt.Errorf("create host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err)
}
return nil
}
func createUncleanShutdownIndicator(guid string) error {
file := getUncleanShutdownFile()
dir := filepath.Dir(file)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
return fmt.Errorf("create %s: %w", file, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
file := getUncleanShutdownFile()
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", file, err)
}
return nil
}
func getUncleanShutdownFile() string {
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
}

View File

@ -219,7 +219,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
return fmt.Errorf("got an error from upstream check call") return fmt.Errorf("upstream check call error")
} }
err := backoff.Retry(operation, exponentialBackOff) err := backoff.Retry(operation, exponentialBackOff)

View File

@ -1080,6 +1080,11 @@ func (e *Engine) close() {
log.Errorf("failed closing ebpf proxy: %s", err) log.Errorf("failed closing ebpf proxy: %s", err)
} }
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil { if err := e.wgInterface.Close(); err != nil {
@ -1098,10 +1103,6 @@ func (e *Engine) close() {
e.routeManager.Stop() e.routeManager.Stop()
} }
if e.dnsServer != nil {
e.dnsServer.Stop()
}
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Reset() err := e.firewall.Reset()
if err != nil { if err != nil {

View File

@ -39,7 +39,6 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
return wgIFace, nil return wgIFace, nil
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only