diff --git a/client/cmd/service.go b/client/cmd/service.go index 7a6729850..18fe5d621 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -32,6 +32,7 @@ func newSVCConfig() *service.Config { Name: name, DisplayName: "Netbird", Description: "A WireGuard-based mesh network that connects your devices into a single private network.", + Option: make(service.KeyValue), } } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 86439ad17..8efb5ee60 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "path/filepath" "runtime" "github.com/spf13/cobra" @@ -32,8 +33,13 @@ var installCmd = &cobra.Command{ } if managementURL != "" { - svcConfig.Arguments = append(svcConfig.Arguments, "--management-url") - svcConfig.Arguments = append(svcConfig.Arguments, managementURL) + svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) + } + + if logFile != "console" { + svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) + svcConfig.Option["LogOutput"] = true + svcConfig.Option["LogDirectory"] = filepath.Dir(logFile) } if runtime.GOOS == "linux" { diff --git a/client/internal/dns/dbus_linux.go b/client/internal/dns/dbus_linux.go new file mode 100644 index 000000000..0f6d4156a --- /dev/null +++ b/client/internal/dns/dbus_linux.go @@ -0,0 +1,41 @@ +package dns + +import ( + "context" + "github.com/godbus/dbus/v5" + log "github.com/sirupsen/logrus" + "time" +) + +const dbusDefaultFlag = 0 + +func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool { + obj, closeConn, err := getDbusObject(dest, path) + if err != nil { + return false + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store() + return err == nil +} + +func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) { + conn, err := dbus.SystemBus() + if err != nil { + return nil, nil, err + } + obj := conn.Object(dest, path) + + closeFunc := func() { + closeErr := conn.Close() + if closeErr != nil { + log.Warnf("got an error closing dbus connection, err: %s", closeErr) + } + } + + return obj, closeFunc, nil +} diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go new file mode 100644 index 000000000..ad6e3a37f --- /dev/null +++ b/client/internal/dns/file_linux.go @@ -0,0 +1,154 @@ +package dns + +import ( + "bytes" + "fmt" + log "github.com/sirupsen/logrus" + "os" +) + +const ( + fileGeneratedResolvConfContentHeader = "# Generated by NetBird" + fileGeneratedResolvConfSearchBeginContent = "search " + fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader + + "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + + fileGeneratedResolvConfSearchBeginContent + "%s\n" +) +const ( + fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" + fileMaxLineCharsLimit = 256 + fileMaxNumberOfSearchDomains = 6 +) + +var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent) + +type fileConfigurator struct { + originalPerms os.FileMode +} + +func newFileConfigurator() (hostManager, error) { + return &fileConfigurator{}, nil +} + +func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error { + backupFileExist := false + _, err := os.Stat(fileDefaultResolvConfBackupLocation) + if err == nil { + backupFileExist = true + } + + if !config.routeAll { + if backupFileExist { + err = f.restore() + if err != nil { + return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err) + } + } + return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group") + } + managerType, err := getOSDNSManagerType() + if err != nil { + return err + } + switch managerType { + case fileManager, netbirdManager: + if !backupFileExist { + err = f.backup() + if err != nil { + return fmt.Errorf("unable to backup the resolv.conf file") + } + } + default: + // todo improve this and maybe restart DNS manager from scratch + return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent") + } + + var searchDomains string + appendedDomains := 0 + for _, dConf := range config.domains { + if dConf.matchOnly { + continue + } + if appendedDomains >= fileMaxNumberOfSearchDomains { + // lets log all skipped domains + log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain) + continue + } + if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit { + // lets log all skipped domains + log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain) + continue + } + + searchDomains += " " + dConf.domain + appendedDomains++ + } + content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains) + err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms) + if err != nil { + err = f.restore() + if err != nil { + log.Errorf("attempt to restore default file failed with error: %s", err) + } + return err + } + log.Infof("created a NetBird managed %s file with your DNS settings", defaultResolvConfPath) + return nil +} + +func (f *fileConfigurator) restoreHostDNS() error { + return f.restore() +} + +func (f *fileConfigurator) backup() error { + stats, err := os.Stat(defaultResolvConfPath) + if err != nil { + return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err) + } + + f.originalPerms = stats.Mode() + + err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation) + if err != nil { + return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err) + } + return nil +} + +func (f *fileConfigurator) restore() error { + err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) + if err != nil { + return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) + } + + return os.RemoveAll(fileDefaultResolvConfBackupLocation) +} + +func writeDNSConfig(content, fileName string, permissions os.FileMode) error { + log.Debugf("creating managed file %s", fileName) + var buf bytes.Buffer + buf.WriteString(content) + err := os.WriteFile(fileName, buf.Bytes(), permissions) + if err != nil { + return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err) + } + return nil +} + +func copyFile(src, dest string) error { + stats, err := os.Stat(src) + if err != nil { + return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err) + } + + bytesRead, err := os.ReadFile(src) + if err != nil { + return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err) + } + + err = os.WriteFile(dest, bytesRead, stats.Mode()) + if err != nil { + return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err) + } + return nil +} diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go new file mode 100644 index 000000000..c077e2032 --- /dev/null +++ b/client/internal/dns/host.go @@ -0,0 +1,79 @@ +package dns + +import ( + "fmt" + nbdns "github.com/netbirdio/netbird/dns" + "strings" +) + +type hostManager interface { + applyDNSConfig(config hostDNSConfig) error + restoreHostDNS() error +} + +type hostDNSConfig struct { + domains []domainConfig + routeAll bool + serverIP string + serverPort int +} + +type domainConfig struct { + domain string + matchOnly bool +} + +type mockHostConfigurator struct { + applyDNSConfigFunc func(config hostDNSConfig) error + restoreHostDNSFunc func() error +} + +func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error { + if m.applyDNSConfigFunc != nil { + return m.applyDNSConfigFunc(config) + } + return fmt.Errorf("method applyDNSSettings is not implemented") +} + +func (m *mockHostConfigurator) restoreHostDNS() error { + if m.restoreHostDNSFunc != nil { + return m.restoreHostDNSFunc() + } + return fmt.Errorf("method restoreHostDNS is not implemented") +} + +func newNoopHostMocker() hostManager { + return &mockHostConfigurator{ + applyDNSConfigFunc: func(config hostDNSConfig) error { return nil }, + restoreHostDNSFunc: func() error { return nil }, + } +} + +func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig { + config := hostDNSConfig{ + routeAll: false, + serverIP: ip, + serverPort: port, + } + for _, nsConfig := range dnsConfig.NameServerGroups { + if nsConfig.Primary { + config.routeAll = true + } + + for _, domain := range nsConfig.Domains { + config.domains = append(config.domains, domainConfig{ + domain: strings.TrimSuffix(domain, "."), + matchOnly: true, + }) + } + } + + for _, customZone := range dnsConfig.CustomZones { + config.domains = append(config.domains, domainConfig{ + domain: strings.TrimSuffix(customZone.Domain, "."), + matchOnly: false, + }) + } + + return config +} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go new file mode 100644 index 000000000..546561d88 --- /dev/null +++ b/client/internal/dns/host_darwin.go @@ -0,0 +1,259 @@ +package dns + +import ( + "bufio" + "bytes" + "fmt" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "os/exec" + "strconv" + "strings" +) + +const ( + netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" + globalIPv4State = "State:/Network/Global/IPv4" + primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS" + keySupplementalMatchDomains = "SupplementalMatchDomains" + keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" + keyServerAddresses = "ServerAddresses" + keyServerPort = "ServerPort" + arraySymbol = "* " + digitSymbol = "# " + scutilPath = "/usr/sbin/scutil" + searchSuffix = "Search" + matchSuffix = "Match" +) + +type systemConfigurator struct { + // primaryServiceID primary interface in the system. AKA the interface with the default route + primaryServiceID string + createdKeys map[string]struct{} +} + +func newHostManager(_ *iface.WGIface) (hostManager, error) { + return &systemConfigurator{ + createdKeys: make(map[string]struct{}), + }, nil +} + +func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { + var err error + + if config.routeAll { + err = s.addDNSSetupForAll(config.serverIP, config.serverPort) + if err != nil { + return err + } + } else if s.primaryServiceID != "" { + err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) + if err != nil { + return err + } + s.primaryServiceID = "" + log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort) + } + + var ( + searchDomains []string + matchDomains []string + ) + + for _, dConf := range config.domains { + if dConf.matchOnly { + matchDomains = append(matchDomains, dConf.domain) + continue + } + searchDomains = append(searchDomains, dConf.domain) + } + + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + if len(matchDomains) != 0 { + err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort) + } else { + log.Infof("removing match domains from the system") + err = s.removeKeyFromSystemConfig(matchKey) + } + if err != nil { + return err + } + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + if len(searchDomains) != 0 { + err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort) + } else { + log.Infof("removing search domains from the system") + err = s.removeKeyFromSystemConfig(searchKey) + } + if err != nil { + return err + } + + return nil +} + +func (s *systemConfigurator) restoreHostDNS() error { + lines := "" + for key := range s.createdKeys { + lines += buildRemoveKeyOperation(key) + keyType := "search" + if strings.Contains(key, matchSuffix) { + keyType = "match" + } + log.Infof("removing %s domains from system", keyType) + } + if s.primaryServiceID != "" { + lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) + log.Infof("restoring DNS resolver configuration for system") + } + _, err := runSystemConfigCommand(wrapCommand(lines)) + if err != nil { + log.Errorf("got an error while cleaning the system configuration: %s", err) + return err + } + + return nil +} + +func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { + line := buildRemoveKeyOperation(key) + _, err := runSystemConfigCommand(wrapCommand(line)) + if err != nil { + return err + } + + delete(s.createdKeys, key) + + return nil +} + +func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { + err := s.addDNSState(key, domains, ip, port, true) + if err != nil { + return err + } + + log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) + + s.createdKeys[key] = struct{}{} + + return nil +} + +func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error { + err := s.addDNSState(key, domains, dnsServer, port, false) + if err != nil { + return err + } + + log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) + + s.createdKeys[key] = struct{}{} + + return nil +} + +func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port int, enableSearch bool) error { + noSearch := "1" + if enableSearch { + noSearch = "0" + } + lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains) + lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) + lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) + + addDomainCommand := buildCreateStateWithOperation(state, lines) + stdinCommands := wrapCommand(addDomainCommand) + + _, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err) + } + return nil +} + +func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { + primaryServiceKey := s.getPrimaryService() + if primaryServiceKey == "" { + return fmt.Errorf("couldn't find the primary service key") + } + + err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port) + if err != nil { + return err + } + log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port) + s.primaryServiceID = primaryServiceKey + return nil +} + +func (s *systemConfigurator) getPrimaryService() string { + line := buildCommandLine("show", globalIPv4State, "") + stdinCommands := wrapCommand(line) + b, err := runSystemConfigCommand(stdinCommands) + if err != nil { + log.Error("got error while sending the command: ", err) + return "" + } + scanner := bufio.NewScanner(bytes.NewReader(b)) + for scanner.Scan() { + text := scanner.Text() + if strings.Contains(text, "PrimaryService") { + return strings.TrimSpace(strings.Split(text, ":")[1]) + } + } + return "" +} + +func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error { + lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0)) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) + lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) + addDomainCommand := buildCreateStateWithOperation(setupKey, lines) + stdinCommands := wrapCommand(addDomainCommand) + _, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return fmt.Errorf("got error while applying dns setup, error: %s", err) + } + return nil +} + +func getKeyWithInput(format, key string) string { + return fmt.Sprintf(format, key) +} + +func buildAddCommandLine(key, value string) string { + return buildCommandLine("d.add", key, value) +} + +func buildCommandLine(action, key, value string) string { + return fmt.Sprintf("%s %s %s\n", action, key, value) +} + +func wrapCommand(commands string) string { + return fmt.Sprintf("open\n%s\nquit\n", commands) +} + +func buildRemoveKeyOperation(key string) string { + return fmt.Sprintf("remove %s\n", key) +} + +func buildCreateStateWithOperation(state, commands string) string { + return buildWriteStateOperation("set", state, commands) +} + +func buildWriteStateOperation(operation, state, commands string) string { + return fmt.Sprintf("d.init\n%s %s\n%s\nset %s\n", operation, state, commands, state) +} + +func runSystemConfigCommand(command string) ([]byte, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader(command) + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err) + } + return out, nil +} diff --git a/client/internal/dns/host_linux.go b/client/internal/dns/host_linux.go new file mode 100644 index 000000000..ffb5098c7 --- /dev/null +++ b/client/internal/dns/host_linux.go @@ -0,0 +1,75 @@ +package dns + +import ( + "bufio" + "fmt" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "os" + "strings" +) + +const ( + defaultResolvConfPath = "/etc/resolv.conf" +) + +const ( + netbirdManager osManagerType = iota + fileManager + networkManager + systemdManager + resolvConfManager +) + +type osManagerType int + +func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { + osManager, err := getOSDNSManagerType() + if err != nil { + return nil, err + } + + log.Debugf("discovered mode is: %d", osManager) + switch osManager { + case networkManager: + return newNetworkManagerDbusConfigurator(wgInterface) + case systemdManager: + return newSystemdDbusConfigurator(wgInterface) + default: + return newFileConfigurator() + } +} + +func getOSDNSManagerType() (osManagerType, error) { + + file, err := os.Open(defaultResolvConfPath) + if err != nil { + return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + if len(text) == 0 { + continue + } + if text[0] != '#' { + return fileManager, nil + } + if strings.Contains(text, fileGeneratedResolvConfContentHeader) { + return netbirdManager, nil + } + if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { + log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion()) + return networkManager, nil + } + if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { + return systemdManager, nil + } + if strings.Contains(text, "resolvconf") { + return resolvConfManager, nil + } + } + return fileManager, nil +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go new file mode 100644 index 000000000..e3f6cf34c --- /dev/null +++ b/client/internal/dns/host_windows.go @@ -0,0 +1,260 @@ +package dns + +import ( + "fmt" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows/registry" + "strings" +) + +const ( + dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match" + dnsPolicyConfigVersionKey = "Version" + dnsPolicyConfigVersionValue = 2 + dnsPolicyConfigNameKey = "Name" + dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers" + dnsPolicyConfigConfigOptionsKey = "ConfigOptions" + dnsPolicyConfigConfigOptionsValue = 0x8 +) + +const ( + interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces" + interfaceConfigNameServerKey = "NameServer" + interfaceConfigSearchListKey = "SearchList" + tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters" +) + +type registryConfigurator struct { + guid string + routingAll bool + existingSearchDomains []string +} + +func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { + guid, err := wgInterface.GetInterfaceGUIDString() + if err != nil { + return nil, err + } + return ®istryConfigurator{ + guid: guid, + }, nil +} + +func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { + var err error + if config.routeAll { + err = r.addDNSSetupForAll(config.serverIP) + if err != nil { + return err + } + } else if r.routingAll { + err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey) + if err != nil { + return err + } + r.routingAll = false + log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP) + } + + var ( + searchDomains []string + matchDomains []string + ) + + for _, dConf := range config.domains { + if !dConf.matchOnly { + searchDomains = append(searchDomains, dConf.domain) + } + matchDomains = append(matchDomains, "."+dConf.domain) + } + + if len(matchDomains) != 0 { + err = r.addDNSMatchPolicy(matchDomains, config.serverIP) + } else { + err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) + } + if err != nil { + return err + } + + err = r.updateSearchDomains(searchDomains) + if err != nil { + return err + } + + return nil +} + +func (r *registryConfigurator) addDNSSetupForAll(ip string) error { + err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip) + if err != nil { + return fmt.Errorf("adding dns setup for all failed with error: %s", err) + } + r.routingAll = true + log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) + return nil +} + +func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error { + _, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE) + if err == nil { + err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath) + if err != nil { + return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err) + } + } + + regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err) + } + + err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err) + } + + err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err) + } + + err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err) + } + + err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err) + } + + log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains) + + return nil +} + +func (r *registryConfigurator) restoreHostDNS() error { + err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) + if err != nil { + log.Error(err) + } + + return r.updateSearchDomains([]string{}) +} + +func (r *registryConfigurator) updateSearchDomains(domains []string) error { + value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey) + if err != nil { + return fmt.Errorf("unable to get current search domains failed with error: %s", err) + } + + valueList := strings.Split(value, ",") + setExisting := false + if len(r.existingSearchDomains) == 0 { + r.existingSearchDomains = valueList + setExisting = true + } + + if len(domains) == 0 && setExisting { + log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains) + return nil + } + + newList := append(r.existingSearchDomains, domains...) + + err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ",")) + if err != nil { + return fmt.Errorf("adding search domain failed with error: %s", err) + } + + log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains) + + return nil +} + +func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error { + regKey, err := r.getInterfaceRegistryKey() + if err != nil { + return err + } + defer regKey.Close() + + err = regKey.SetStringValue(key, value) + if err != nil { + return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err) + } + + return nil +} + +func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error { + regKey, err := r.getInterfaceRegistryKey() + if err != nil { + return err + } + defer regKey.Close() + + err = regKey.DeleteValue(propertyKey) + if err != nil { + return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err) + } + + return nil +} + +func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { + var regKey registry.Key + + regKeyPath := interfaceConfigPath + "\\" + r.guid + + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE) + if err != nil { + return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err) + } + + return regKey, nil +} + +func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE) + if err == nil { + k.Close() + err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath) + if err != nil { + return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err) + } + } + return nil +} + +func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) { + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE) + if err != nil { + return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err) + } + defer regKey.Close() + + val, _, err := regKey.GetStringValue(key) + if err != nil { + return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err) + } + + return val, nil +} + +func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error { + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err) + } + defer regKey.Close() + + err = regKey.SetStringValue(key, value) + if err != nil { + return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err) + } + + return nil +} diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 741ab97b4..680fcc31a 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -1,6 +1,7 @@ package dns import ( + "fmt" "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" @@ -14,16 +15,16 @@ type localResolver struct { // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v\n", r.Question[0]) - response := d.lookupRecord(r) - if response == nil { - log.Debugf("got empty response for question: %#v\n", r.Question[0]) - return - } - + log.Debugf("received question: %#v\n", r.Question[0]) replyMessage := &dns.Msg{} replyMessage.SetReply(r) - replyMessage.Answer = append(replyMessage.Answer, response) + replyMessage.RecursionAvailable = true + replyMessage.Rcode = dns.RcodeSuccess + + response := d.lookupRecord(r) + if response != nil { + replyMessage.Answer = append(replyMessage.Answer, response) + } err := w.WriteMsg(replyMessage) if err != nil { @@ -32,7 +33,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR { - record, found := d.records.Load(r.Question[0].Name) + question := r.Question[0] + record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype)) if !found { return nil } @@ -46,7 +48,10 @@ func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error { return err } - d.records.Store(fullRecord.Header().Name, fullRecord) + fullRecord.Header().Rdlength = record.Len() + + header := fullRecord.Header() + d.records.Store(buildRecordKey(header.Name, header.Class, header.Rrtype), fullRecord) return nil } @@ -54,3 +59,8 @@ func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error { func (d *localResolver) deleteRecord(recordKey string) { d.records.Delete(dns.Fqdn(recordKey)) } + +func buildRecordKey(name string, class, qType uint16) string { + key := fmt.Sprintf("%s_%d_%d", name, class, qType) + return key +} diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go index 79a57881b..db69d9ad8 100644 --- a/client/internal/dns/local_test.go +++ b/client/internal/dns/local_test.go @@ -64,7 +64,7 @@ func TestLocalResolver_ServeDNS(t *testing.T) { resolver.ServeDNS(responseWriter, testCase.inputMSG) - if responseMSG == nil { + if responseMSG == nil || len(responseMSG.Answer) == 0 { if testCase.responseShouldBeNil { return } diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_linux.go new file mode 100644 index 000000000..955d54923 --- /dev/null +++ b/client/internal/dns/network_manager_linux.go @@ -0,0 +1,295 @@ +package dns + +import ( + "context" + "encoding/binary" + "fmt" + "github.com/godbus/dbus/v5" + "github.com/hashicorp/go-version" + "github.com/miekg/dns" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "net/netip" + "regexp" + "time" +) + +const ( + networkManagerDest = "org.freedesktop.NetworkManager" + networkManagerDbusObjectNode = "/org/freedesktop/NetworkManager" + networkManagerDbusDNSManagerInterface = "org.freedesktop.NetworkManager.DnsManager" + networkManagerDbusDNSManagerObjectNode = networkManagerDbusObjectNode + "/DnsManager" + networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode" + networkManagerDbusDNSManagerRcManagerProperty = networkManagerDbusDNSManagerInterface + ".RcManager" + networkManagerDbusVersionProperty = "org.freedesktop.NetworkManager.Version" + networkManagerDbusGetDeviceByIPIfaceMethod = networkManagerDest + ".GetDeviceByIpIface" + networkManagerDbusDeviceInterface = "org.freedesktop.NetworkManager.Device" + networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection" + networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply" + networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete" + networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0 + networkManagerDbusIPv4Key = "ipv4" + networkManagerDbusIPv6Key = "ipv6" + networkManagerDbusDNSKey = "dns" + networkManagerDbusDNSSearchKey = "dns-search" + networkManagerDbusDNSPriorityKey = "dns-priority" + + // dns priority doc https://wiki.gnome.org/Projects/NetworkManager/DNS + networkManagerDbusPrimaryDNSPriority int32 = -500 + networkManagerDbusWithMatchDomainPriority int32 = 0 + networkManagerDbusSearchDomainOnlyPriority int32 = 50 + supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28" +) + +type networkManagerDbusConfigurator struct { + dbusLinkObject dbus.ObjectPath + routingAll bool +} + +// the types below are based on dbus specification, each field is mapped to a dbus type +// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types +// see https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.Device.html on Network Manager input types + +// networkManagerConnSettings maps to a (a{sa{sv}}) dbus output from GetAppliedConnection and input for Reapply methods +type networkManagerConnSettings map[string]map[string]dbus.Variant + +// networkManagerConfigVersion maps to a (t) dbus output from GetAppliedConnection and input for Reapply methods +type networkManagerConfigVersion uint64 + +// networkManagerConfigBehavior maps to a (u) dbus input for GetAppliedConnection and Reapply methods +type networkManagerConfigBehavior uint32 + +// cleanDeprecatedSettings cleans deprecated settings that still returned by +// the GetAppliedConnection methods but can't be reApplied +func (s networkManagerConnSettings) cleanDeprecatedSettings() { + for _, key := range []string{"addresses", "routes"} { + delete(s[networkManagerDbusIPv4Key], key) + delete(s[networkManagerDbusIPv6Key], key) + } +} + +func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + return nil, err + } + defer closeConn() + var s string + err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s) + if err != nil { + return nil, err + } + + log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName()) + + return &networkManagerDbusConfigurator{ + dbusLinkObject: dbus.ObjectPath(s), + }, nil +} + +func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { + connSettings, configVersion, err := n.getAppliedConnectionSettings() + if err != nil { + return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err) + } + + connSettings.cleanDeprecatedSettings() + + dnsIP := netip.MustParseAddr(config.serverIP) + convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice()) + connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + var ( + searchDomains []string + matchDomains []string + ) + for _, dConf := range config.domains { + if dConf.matchOnly { + matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain)) + continue + } + searchDomains = append(searchDomains, dns.Fqdn(dConf.domain)) + } + + newDomainList := append(searchDomains, matchDomains...) + + priority := networkManagerDbusSearchDomainOnlyPriority + switch { + case config.routeAll: + priority = networkManagerDbusPrimaryDNSPriority + newDomainList = append(newDomainList, "~.") + if !n.routingAll { + log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + } + case len(matchDomains) > 0: + priority = networkManagerDbusWithMatchDomainPriority + } + + if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll { + log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + n.routingAll = false + } + + connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) + connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) + + log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) + err = n.reApplyConnectionSettings(connSettings, configVersion) + if err != nil { + return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err) + } + return nil +} + +func (n *networkManagerDbusConfigurator) restoreHostDNS() error { + // once the interface is gone network manager cleans all config associated with it + return n.deleteConnectionSettings() +} + +func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) + if err != nil { + return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + var ( + connSettings networkManagerConnSettings + configVersion networkManagerConfigVersion + ) + + err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag, + networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion) + if err != nil { + return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err) + } + + return connSettings, configVersion, nil +} + +func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error { + obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag, + connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store() + if err != nil { + return fmt.Errorf("got error while calling ReApply method with context, err: %s", err) + } + + return nil +} + +func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error { + obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store() + if err != nil { + return fmt.Errorf("got error while calling delete method with context, err: %s", err) + } + + return nil +} + +func isNetworkManagerSupported() bool { + return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode() +} + +func isNetworkManagerSupportedMode() bool { + var mode string + err := getNetworkManagerDNSProperty(networkManagerDbusDNSManagerModeProperty, &mode) + if err != nil { + log.Error(err) + return false + } + switch mode { + case "dnsmasq", "unbound", "systemd-resolved": + return true + default: + var rcManager string + err = getNetworkManagerDNSProperty(networkManagerDbusDNSManagerRcManagerProperty, &rcManager) + if err != nil { + log.Error(err) + return false + } + if rcManager == "unmanaged" { + return false + } + } + return true +} + +func getNetworkManagerDNSProperty(property string, store any) error { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err) + } + defer closeConn() + + v, e := obj.GetProperty(property) + if e != nil { + return fmt.Errorf("got an error getting property %s: %v", property, e) + } + + return v.Store(store) +} + +func isNetworkManagerSupportedVersion() bool { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + log.Errorf("got error while attempting to get the network manager object, err: %s", err) + return false + } + + defer closeConn() + + value, err := obj.GetProperty(networkManagerDbusVersionProperty) + if err != nil { + log.Errorf("unable to retrieve network manager mode, got error: %s", err) + return false + } + versionValue, err := parseVersion(value.Value().(string)) + if err != nil { + return false + } + + constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint) + if err != nil { + return false + } + + return constraints.Check(versionValue) +} + +func parseVersion(inputVersion string) (*version.Version, error) { + reg, err := regexp.Compile(version.SemverRegexpRaw) + if err != nil { + return nil, err + } + + if inputVersion == "" || !reg.MatchString(inputVersion) { + return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer") + } + + verObj, err := version.NewVersion(inputVersion) + if err != nil { + return nil, err + } + + return verObj, nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 67f3788ea..91a38cd4a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,14 +5,19 @@ import ( "fmt" "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" log "github.com/sirupsen/logrus" + "net" + "net/netip" + "runtime" "sync" "time" ) const ( - port = 5053 - defaultIP = "0.0.0.0" + port = 53 + customPort = 5053 + defaultIP = "127.0.0.1" ) // Server is a dns server interface @@ -31,8 +36,12 @@ type DefaultServer struct { dnsMux *dns.ServeMux dnsMuxMap registrationMap localResolver *localResolver + wgInterface *iface.WGIface + hostManager hostManager updateSerial uint64 listenerIsRunning bool + runtimePort int + runtimeIP string } type registrationMap map[string]struct{} @@ -43,11 +52,15 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context) *DefaultServer { +func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*DefaultServer, error) { mux := dns.NewServeMux() + listenIP := defaultIP + if runtime.GOOS != "darwin" && wgInterface != nil { + listenIP = wgInterface.GetAddress().IP.String() + } dnsServer := &dns.Server{ - Addr: fmt.Sprintf("%s:%d", defaultIP, port), + Addr: fmt.Sprintf("%s:%d", listenIP, port), Net: "udp", Handler: mux, UDPSize: 65535, @@ -55,7 +68,7 @@ func NewDefaultServer(ctx context.Context) *DefaultServer { ctx, stop := context.WithCancel(ctx) - return &DefaultServer{ + defaultServer := &DefaultServer{ ctx: ctx, stop: stop, server: dnsServer, @@ -64,18 +77,44 @@ func NewDefaultServer(ctx context.Context) *DefaultServer { localResolver: &localResolver{ registeredMap: make(registrationMap), }, + wgInterface: wgInterface, + runtimePort: port, + runtimeIP: listenIP, } + + hostmanager, err := newHostManager(wgInterface) + if err != nil { + return nil, err + } + defaultServer.hostManager = hostmanager + return defaultServer, err } // Start runs the listener in a go routine func (s *DefaultServer) Start() { - log.Debugf("starting dns on %s:%d", defaultIP, port) + s.runtimePort = port + udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(s.server.Addr)) + probeListener, err := net.ListenUDP("udp", udpAddr) + if err != nil { + log.Warnf("using a custom port for dns server") + s.runtimePort = customPort + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, customPort) + } else { + err = probeListener.Close() + if err != nil { + log.Errorf("got an error closing the probe listener, error: %s", err) + } + } + + log.Debugf("starting dns on %s", s.server.Addr) + go func() { s.setListenerStatus(true) defer s.setListenerStatus(false) - err := s.server.ListenAndServe() + + err = s.server.ListenAndServe() if err != nil { - log.Errorf("dns server returned an error: %v", err) + log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) } }() } @@ -86,9 +125,16 @@ func (s *DefaultServer) setListenerStatus(running bool) { // Stop stops the server func (s *DefaultServer) Stop() { + s.mux.Lock() + defer s.mux.Unlock() s.stop() - err := s.stopListener() + err := s.hostManager.restoreHostDNS() + if err != nil { + log.Error(err) + } + + err = s.stopListener() if err != nil { log.Error(err) } @@ -148,6 +194,11 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro s.updateMux(muxUpdates) s.updateLocalResolver(localRecords) + err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)) + if err != nil { + log.Error(err) + } + s.updateSerial = serial return nil @@ -170,7 +221,12 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) }) for _, record := range customZone.Records { - localRecords[record.Name] = record + var class uint16 = dns.ClassINET + if record.Class != nbdns.DefaultClass { + return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) + } + key := buildRecordKey(record.Name, class, uint16(record.Type)) + localRecords[key] = record } } return muxUpdates, localRecords, nil diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 6bbfef507..b0b8cd1ec 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "fmt" + "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" "net" "net/netip" @@ -74,12 +75,12 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}}, - expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "New Config Should Succeed", initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + initUpstreamMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -98,7 +99,7 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}}, - expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "Smaller Config Serial Should Be Skipped", @@ -188,12 +189,14 @@ func TestUpdateDNSServer(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - ctx := context.Background() - dnsServer := NewDefaultServer(ctx) + dnsServer := getDefaultServerWithNoHostManager("127.0.0.1") + + dnsServer.hostManager = newNoopHostMocker() dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.localResolver.registeredMap = testCase.initLocalMap dnsServer.updateSerial = testCase.initSerial + // pretend we are running dnsServer.listenerIsRunning = true err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -230,13 +233,15 @@ func TestUpdateDNSServer(t *testing.T) { } func TestDNSServerStartStop(t *testing.T) { - ctx := context.Background() - dnsServer := NewDefaultServer(ctx) + dnsServer := getDefaultServerWithNoHostManager("127.0.0.1") + if runtime.GOOS == "windows" && os.Getenv("CI") == "true" { // todo review why this test is not working only on github actions workflows t.Skip("skipping test in Windows CI workflows.") } + dnsServer.hostManager = newNoopHostMocker() + dnsServer.Start() err := dnsServer.localResolver.registerRecord(zoneRecords[0]) @@ -276,10 +281,40 @@ func TestDNSServerStartStop(t *testing.T) { } dnsServer.Stop() - ctx, cancel := context.WithTimeout(ctx, time.Second*1) + ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1) defer cancel() _, err = resolver.LookupHost(ctx, zoneRecords[0].Name) if err == nil { t.Fatalf("we should encounter an error when querying a stopped server") } } + +func getDefaultServerWithNoHostManager(ip string) *DefaultServer { + mux := dns.NewServeMux() + listenIP := defaultIP + if ip != "" { + listenIP = ip + } + + dnsServer := &dns.Server{ + Addr: fmt.Sprintf("%s:%d", ip, port), + Net: "udp", + Handler: mux, + UDPSize: 65535, + } + + ctx, stop := context.WithCancel(context.TODO()) + + return &DefaultServer{ + ctx: ctx, + stop: stop, + server: dnsServer, + dnsMux: mux, + dnsMuxMap: make(registrationMap), + localResolver: &localResolver{ + registeredMap: make(registrationMap), + }, + runtimePort: port, + runtimeIP: listenIP, + } +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go new file mode 100644 index 000000000..54a73968a --- /dev/null +++ b/client/internal/dns/systemd_linux.go @@ -0,0 +1,185 @@ +package dns + +import ( + "context" + "fmt" + "github.com/godbus/dbus/v5" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + "net" + "net/netip" + "time" +) + +const ( + systemdDbusManagerInterface = "org.freedesktop.resolve1.Manager" + systemdResolvedDest = "org.freedesktop.resolve1" + systemdDbusObjectNode = "/org/freedesktop/resolve1" + systemdDbusGetLinkMethod = systemdDbusManagerInterface + ".GetLink" + systemdDbusFlushCachesMethod = systemdDbusManagerInterface + ".FlushCaches" + systemdDbusLinkInterface = "org.freedesktop.resolve1.Link" + systemdDbusRevertMethodSuffix = systemdDbusLinkInterface + ".Revert" + systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS" + systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" + systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" +) + +type systemdDbusConfigurator struct { + dbusLinkObject dbus.ObjectPath + routingAll bool +} + +// the types below are based on dbus specification, each field is mapped to a dbus type +// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types +// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types +// systemdDbusDNSInput maps to a (iay) dbus input for SetDNS method +type systemdDbusDNSInput struct { + Family int32 + Address []byte +} + +// systemdDbusLinkDomainsInput maps to a (sb) dbus input for SetDomains method +type systemdDbusLinkDomainsInput struct { + Domain string + MatchOnly bool +} + +func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { + iface, err := net.InterfaceByName(wgInterface.GetName()) + if err != nil { + return nil, err + } + + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return nil, err + } + defer closeConn() + + var s string + err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s) + if err != nil { + return nil, err + } + + log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index) + + return &systemdDbusConfigurator{ + dbusLinkObject: dbus.ObjectPath(s), + }, nil +} + +func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { + parsedIP := netip.MustParseAddr(config.serverIP).As4() + defaultLinkInput := systemdDbusDNSInput{ + Family: unix.AF_INET, + Address: parsedIP[:], + } + err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) + if err != nil { + return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err) + } + + var ( + searchDomains []string + matchDomains []string + domainsInput []systemdDbusLinkDomainsInput + ) + for _, dConf := range config.domains { + domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ + Domain: dns.Fqdn(dConf.domain), + MatchOnly: dConf.matchOnly, + }) + + if dConf.matchOnly { + matchDomains = append(matchDomains, dConf.domain) + continue + } + searchDomains = append(searchDomains, dConf.domain) + } + + if config.routeAll { + log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) + if err != nil { + return fmt.Errorf("setting link as default dns router, failed with error: %s", err) + } + domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ + Domain: nbdns.RootZone, + MatchOnly: true, + }) + s.routingAll = true + } else if s.routingAll { + log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + } + + log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) + err = s.setDomainsForInterface(domainsInput) + if err != nil { + log.Error(err) + } + return nil +} + +func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error { + err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput) + if err != nil { + return fmt.Errorf("setting domains configuration failed with error: %s", err) + } + return s.flushCaches() +} + +func (s *systemdDbusConfigurator) restoreHostDNS() error { + log.Infof("reverting link settings and flushing cache") + if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) { + return nil + } + err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil) + if err != nil { + return fmt.Errorf("unable to revert link configuration, got error: %s", err) + } + return s.flushCaches() +} + +func (s *systemdDbusConfigurator) flushCaches() error { + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the object %s, err: %s", systemdDbusObjectNode, err) + } + defer closeConn() + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store() + if err != nil { + return fmt.Errorf("got error while calling the FlushCaches method with context, err: %s", err) + } + + return nil +} + +func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error { + obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the object, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + if value != nil { + err = obj.CallWithContext(ctx, method, dbusDefaultFlag, value).Store() + } else { + err = obj.CallWithContext(ctx, method, dbusDefaultFlag).Store() + } + + if err != nil { + return fmt.Errorf("got error while calling command with context, err: %s", err) + } + + return nil +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index fcc8bc685..e2e61203c 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -21,7 +21,7 @@ type upstreamResolver struct { // ServeDNS handles a DNS request func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received an upstream question: %#v", r.Question[0]) + log.Debugf("received an upstream question: %#v", r.Question[0]) select { case <-u.parentCTX.Done(): diff --git a/client/internal/engine.go b/client/internal/engine.go index 82425e62a..7e64c003d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -139,7 +139,6 @@ func NewEngine( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, - dnsServer: dns.NewDefaultServer(ctx), } } @@ -261,6 +260,14 @@ func (e *Engine) Start() error { e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) + if e.dnsServer == nil { + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface) + if err != nil { + return err + } + e.dnsServer = dnsServer + } + e.receiveSignalEvents() e.receiveManagementEvents() diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 56d9eb66f..9e80f144d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -202,6 +202,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, nbstatus.NewRecorder()) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) + engine.dnsServer = &dns.MockServer{ + UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, + } type testCase struct { name string @@ -551,6 +554,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { } engine.routeManager = mockRouteManager + engine.dnsServer = &dns.MockServer{} defer func() { exitErr := engine.Stop() @@ -797,6 +801,7 @@ func TestEngine_MultiplePeers(t *testing.T) { t.Errorf("unable to create the engine for peer %d with error %v", j, err) return } + engine.dnsServer = &dns.MockServer{} mu.Lock() defer mu.Unlock() err = engine.Start() diff --git a/dns/dns.go b/dns/dns.go index a09e4b5df..16ebd1d96 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/miekg/dns" "golang.org/x/net/idna" + "net" "regexp" "strings" ) @@ -60,6 +61,30 @@ func (s SimpleRecord) String() string { return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData) } +// Len returns the length of the RData field, based on its type +func (s SimpleRecord) Len() uint16 { + emptyString := s.RData == "" + switch s.Type { + case 1: + if emptyString { + return 0 + } + return net.IPv4len + case 5: + if emptyString || s.RData == "." { + return 1 + } + return uint16(len(s.RData) + 1) + case 28: + if emptyString { + return 0 + } + return net.IPv6len + default: + return 0 + } +} + // GetParsedDomainLabel returns a domain label with max 59 characters, // parsed for old Hosts.txt requirements, and converted to ASCII and lowercase func GetParsedDomainLabel(name string) (string, error) { diff --git a/go.mod b/go.mod index e6c0528d2..b7f92bc73 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/eko/gocache/v3 v3.1.1 github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 + github.com/godbus/dbus/v5 v5.1.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/hashicorp/go-version v1.6.0 github.com/libp2p/go-netroute v0.2.0 @@ -75,7 +76,6 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-stack/stack v1.8.0 // indirect - github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index 3a4338ff7..707d7c808 100644 --- a/go.sum +++ b/go.sum @@ -223,8 +223,9 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= diff --git a/iface/iface.go b/iface/iface.go index bdfa78abb..d75c4db86 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -73,13 +73,15 @@ func parseAddress(address string) (WGAddress, error) { func (w *WGIface) Close() error { w.mu.Lock() defer w.mu.Unlock() - + if w.Interface == nil { + return nil + } err := w.Interface.Close() if err != nil { return err } - if runtime.GOOS == "darwin" { + if runtime.GOOS != "windows" { sockPath := "/var/run/wireguard/" + w.Name + ".sock" if _, statErr := os.Stat(sockPath); statErr == nil { statErr = os.Remove(sockPath) diff --git a/iface/iface_unix.go b/iface/iface_unix.go index 66d316997..ebac5d8a1 100644 --- a/iface/iface_unix.go +++ b/iface/iface_unix.go @@ -75,3 +75,8 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.Address = addr return w.assignAddr() } + +// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + return "", nil +} diff --git a/iface/iface_windows.go b/iface/iface_windows.go index d38cd3dc4..5c16916b9 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -58,6 +58,20 @@ func (w *WGIface) UpdateAddr(newAddr string) error { return w.assignAddr(luid) } +// GetInterfaceGUIDString returns an interface GUID string +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + if w.Interface == nil { + return "", fmt.Errorf("interface has not been initialized yet") + } + windowsDevice := w.Interface.(*driver.Adapter) + luid := windowsDevice.LUID() + guid, err := luid.GUID() + if err != nil { + return "", err + } + return guid.String(), nil +} + // WireguardModuleIsLoaded check if we can load wireguard mod (linux only) func WireguardModuleIsLoaded() bool { return false