diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 130c88214..6e123c94e 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -4,8 +4,8 @@ package dns import ( "fmt" + "net/netip" "os" - "regexp" "strings" log "github.com/sirupsen/logrus" @@ -15,9 +15,6 @@ const ( defaultResolvConfPath = "/etc/resolv.conf" ) -var timeoutRegex = regexp.MustCompile(`timeout:\d+`) -var attemptsRegex = regexp.MustCompile(`attempts:\d+`) - type resolvConf struct { nameServers []string searchDomains []string @@ -108,40 +105,9 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { return rconf, nil } -// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist, -// otherwise it adds a new option with timeout and attempts. -func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string { - configs := make([]string, len(input)) - copy(configs, input) - - for i, config := range configs { - if strings.HasPrefix(config, "options") { - config = strings.ReplaceAll(config, "rotate", "") - config = strings.Join(strings.Fields(config), " ") - - if strings.Contains(config, "timeout:") { - config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout)) - } else { - config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1) - } - - if strings.Contains(config, "attempts:") { - config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts)) - } else { - config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1) - } - - configs[i] = config - return configs - } - } - - return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts)) -} - // removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position // and writes the file back to the original location -func removeFirstNbNameserver(filename, nameserverIP string) error { +func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error { resolvConf, err := parseResolvConfFile(filename) if err != nil { return fmt.Errorf("parse backup resolv.conf: %w", err) @@ -151,7 +117,7 @@ func removeFirstNbNameserver(filename, nameserverIP string) error { return fmt.Errorf("read %s: %w", filename, err) } - if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP { + if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() { newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1) stat, err := os.Stat(filename) diff --git a/client/internal/dns/file_parser_unix_test.go b/client/internal/dns/file_parser_unix_test.go index 1d6e64683..228a708f1 100644 --- a/client/internal/dns/file_parser_unix_test.go +++ b/client/internal/dns/file_parser_unix_test.go @@ -3,11 +3,13 @@ package dns import ( + "net/netip" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_parseResolvConf(t *testing.T) { @@ -175,52 +177,6 @@ nameserver 192.168.0.1 } } -func TestPrepareOptionsWithTimeout(t *testing.T) { - tests := []struct { - name string - others []string - timeout int - attempts int - expected []string - }{ - { - name: "Append new options with timeout and attempts", - others: []string{"some config"}, - timeout: 2, - attempts: 2, - expected: []string{"some config", "options timeout:2 attempts:2"}, - }, - { - name: "Modify existing options to exclude rotate and include timeout and attempts", - others: []string{"some config", "options rotate someother"}, - timeout: 3, - attempts: 2, - expected: []string{"some config", "options attempts:2 timeout:3 someother"}, - }, - { - name: "Existing options with timeout and attempts are updated", - others: []string{"some config", "options timeout:4 attempts:3"}, - timeout: 5, - attempts: 4, - expected: []string{"some config", "options timeout:5 attempts:4"}, - }, - { - name: "Modify existing options, add missing attempts before timeout", - others: []string{"some config", "options timeout:4"}, - timeout: 4, - attempts: 3, - expected: []string{"some config", "options attempts:3 timeout:4"}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts) - assert.Equal(t, tc.expected, result) - }) - } -} - func TestRemoveFirstNbNameserver(t *testing.T) { testCases := []struct { name string @@ -292,7 +248,9 @@ search localdomain`, err := os.WriteFile(tempFile, []byte(tc.content), 0644) assert.NoError(t, err) - err = removeFirstNbNameserver(tempFile, tc.ipToRemove) + ip, err := netip.ParseAddr(tc.ipToRemove) + require.NoError(t, err, "Failed to parse IP address") + err = removeFirstNbNameserver(tempFile, ip) assert.NoError(t, err) content, err := os.ReadFile(tempFile) diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index 9a9218fa1..75af411df 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -3,6 +3,7 @@ package dns import ( + "net/netip" "path" "path/filepath" "sync" @@ -22,7 +23,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error +type repairConfFn func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error type repair struct { operationFile string @@ -42,7 +43,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP netip.Addr, stateManager *statemanager.Manager) { if f.inotify != nil { return } @@ -136,7 +137,7 @@ func (f *repair) isEventRelevant(event fsnotify.Event) bool { // nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs // check the NetBird related nameserver IP at the first place // check the NetBird related search domains in the search domains list -func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *resolvConf) bool { +func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rConf *resolvConf) bool { if !isContains(nbSearchDomains, rConf.searchDomains) { return true } @@ -145,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *r return true } - if rConf.nameServers[0] != nbNameserverIP { + if rConf.nameServers[0] != nbNameserverIP.String() { return true } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index 3aa0b859e..f22081307 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -4,6 +4,7 @@ package dns import ( "context" + "net/netip" "os" "path/filepath" "testing" @@ -105,14 +106,14 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { + updateFn := func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(operationFile, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) + r.watchFileChanges([]string{"netbird.cloud"}, netip.MustParseAddr("10.0.0.1"), nil) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) if err != nil { @@ -152,14 +153,14 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { + updateFn := func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(tmpLink, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) + r.watchFileChanges([]string{"netbird.cloud"}, netip.MustParseAddr("10.0.0.1"), nil) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) if err != nil { diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 3e338267f..423989f72 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -8,7 +8,6 @@ import ( "net/netip" "os" "strings" - "time" log "github.com/sirupsen/logrus" @@ -18,7 +17,7 @@ import ( const ( fileGeneratedResolvConfContentHeader = "# Generated by NetBird" fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + ` -# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n" +# The original file can be restored from ` + fileDefaultResolvConfBackupLocation + "\n\n" fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" @@ -26,16 +25,11 @@ const ( fileMaxNumberOfSearchDomains = 6 ) -const ( - dnsFailoverTimeout = 4 * time.Second - dnsFailoverAttempts = 1 -) - type fileConfigurator struct { - repair *repair - - originalPerms os.FileMode - nbNameserverIP string + repair *repair + originalPerms os.FileMode + nbNameserverIP netip.Addr + originalNameservers []string } func newFileConfigurator() (*fileConfigurator, error) { @@ -49,22 +43,9 @@ func (f *fileConfigurator) supportCustomPort() bool { } func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - backupFileExist := f.isBackupFileExist() - if !config.RouteAll { - if backupFileExist { - f.repair.stopWatchFileChanges() - err := f.restore() - if err != nil { - return fmt.Errorf("restoring the original resolv.conf file return err: %w", err) - } - } - return ErrRouteAllWithoutNameserverGroup - } - - if !backupFileExist { - err := f.backup() - if err != nil { - return fmt.Errorf("unable to backup the resolv.conf file: %w", err) + if !f.isBackupFileExist() { + if err := f.backup(); err != nil { + return fmt.Errorf("backup resolv.conf: %w", err) } } @@ -76,6 +57,8 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err) } + f.originalNameservers = resolvConf.nameServers + f.repair.stopWatchFileChanges() err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) @@ -86,15 +69,19 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { - searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) - nameServers := generateNsList(nbNameserverIP, cfg) +// getOriginalNameservers returns the nameservers that were found in the original resolv.conf +func (f *fileConfigurator) getOriginalNameservers() []string { + return f.originalNameservers +} + +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP netip.Addr, cfg *resolvConf, stateManager *statemanager.Manager) error { + searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) - options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) buf := prepareResolvConfContent( searchDomainList, - nameServers, - options) + []string{nbNameserverIP.String()}, + cfg.others, + ) log.Debugf("creating managed file %s", defaultResolvConfPath) err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) @@ -197,38 +184,28 @@ func restoreResolvConfFile() error { return nil } -// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list -func generateNsList(nbNameserverIP string, cfg *resolvConf) []string { - ns := make([]string, 1, len(cfg.nameServers)+1) - ns[0] = nbNameserverIP - for _, cfgNs := range cfg.nameServers { - if nbNameserverIP != cfgNs { - ns = append(ns, cfgNs) - } - } - return ns -} - func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer { var buf bytes.Buffer + buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine) for _, cfgLine := range others { buf.WriteString(cfgLine) - buf.WriteString("\n") + buf.WriteByte('\n') } if len(searchDomains) > 0 { buf.WriteString("search ") buf.WriteString(strings.Join(searchDomains, " ")) - buf.WriteString("\n") + buf.WriteByte('\n') } for _, ns := range nameServers { buf.WriteString("nameserver ") buf.WriteString(ns) - buf.WriteString("\n") + buf.WriteByte('\n') } + return buf } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 7e7e7cc2d..36da8fb78 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -15,6 +15,7 @@ const ( PriorityDNSRoute = 75 PriorityUpstream = 50 PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -191,7 +192,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // No handler matched or all handlers passed log.Tracef("no handler found for domain=%s", qname) resp := &dns.Msg{} - resp.SetRcode(r, dns.RcodeNameError) + resp.SetRcode(r, dns.RcodeRefused) if err := w.WriteMsg(resp); err != nil { log.Errorf("failed to write DNS response: %v", err) } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index dbf0f2cfc..fa474afde 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -11,8 +11,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") - const ( ipv4ReverseZone = ".in-addr.arpa." ipv6ReverseZone = ".ip6.arpa." @@ -27,14 +25,14 @@ type hostManager interface { type SystemDNSSettings struct { Domains []string - ServerIP string + ServerIP netip.Addr ServerPort int } type HostDNSConfig struct { Domains []DomainConfig `json:"domains"` RouteAll bool `json:"routeAll"` - ServerIP string `json:"serverIP"` + ServerIP netip.Addr `json:"serverIP"` ServerPort int `json:"serverPort"` } @@ -89,7 +87,7 @@ func newNoopHostMocker() hostManager { } } -func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig { +func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) HostDNSConfig { config := HostDNSConfig{ RouteAll: false, ServerIP: ip, diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index a445bc6c4..820cf9029 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -7,7 +7,7 @@ import ( "bytes" "fmt" "io" - "net" + "net/netip" "os/exec" "strconv" "strings" @@ -165,13 +165,13 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { } func (s *systemConfigurator) addLocalDNS() error { - if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { err := s.recordSystemDNSSettings(true) log.Errorf("Unable to get system DNS configuration") return err } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 { + if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) if err != nil { return fmt.Errorf("couldn't add local network DNS conf: %w", err) @@ -184,7 +184,7 @@ func (s *systemConfigurator) addLocalDNS() error { } func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { - if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force { + if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 && !force { return nil } @@ -238,8 +238,8 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] - if ip := net.ParseIP(address); ip != nil && ip.To4() != nil { - dnsSettings.ServerIP = address + if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { + dnsSettings.ServerIP = ip inServerAddressesArray = false // Stop reading after finding the first IPv4 address } } @@ -250,12 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } // default to 53 port - dnsSettings.ServerPort = 53 + dnsSettings.ServerPort = defaultPort return dnsSettings, nil } -func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { +func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { return fmt.Errorf("add dns state: %w", err) @@ -268,7 +268,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po return nil } -func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error { +func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error { err := s.addDNSState(key, domains, dnsServer, port, false) if err != nil { return fmt.Errorf("add dns state: %w", err) @@ -281,14 +281,14 @@ func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, por return nil } -func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port int, enableSearch bool) error { +func (s *systemConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error { noSearch := "1" if enableSearch { noSearch = "0" } lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains) lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch) - lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer.String()) lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) addDomainCommand := buildCreateStateWithOperation(state, lines) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index f8939328a..648a58207 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net/netip" "os/exec" "strings" "syscall" @@ -210,8 +211,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } -func (r *registryConfigurator) addDNSSetupForAll(ip string) error { - if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil { +func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { + if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) } r.routingAll = true @@ -219,7 +220,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error { return nil } -func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error { +func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error { // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 if r.gpo { @@ -241,7 +242,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er } // configureDNSPolicy handles the actual configuration of a DNS policy at the specified path -func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip string) error { +func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { return fmt.Errorf("remove existing dns policy: %w", err) } @@ -260,7 +261,7 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err) } - if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil { + if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip.String()); err != nil { return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err) } diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index c5dd6e23f..40a2e7384 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "net/netip" "github.com/miekg/dns" @@ -45,8 +46,8 @@ func (m *MockServer) Stop() { } } -func (m *MockServer) DnsIP() string { - return "" +func (m *MockServer) DnsIP() netip.Addr { + return netip.MustParseAddr("100.10.254.255") } func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index caae63a24..5459bc2d7 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -110,11 +110,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st connSettings.cleanDeprecatedSettings() - dnsIP, err := netip.ParseAddr(config.ServerIP) - if err != nil { - return fmt.Errorf("unable to parse ip address, error: %w", err) - } - convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice()) + convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice()) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) var ( searchDomains []string diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 54c4c75bf..6080c1d2c 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -46,9 +46,9 @@ type resolvconf struct { func detectResolvconfType() (resolvconfType, error) { cmd := exec.Command(resolvconfCommand, "--version") - out, err := cmd.Output() + out, err := cmd.CombinedOutput() if err != nil { - return typeOpenresolv, fmt.Errorf("failed to determine resolvconf type: %w", err) + return typeOpenresolv, fmt.Errorf("determine resolvconf type: %w", err) } if strings.Contains(string(out), "openresolv") { @@ -66,7 +66,7 @@ func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { implType, err := detectResolvconfType() if err != nil { log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err) - implType = typeOpenresolv + implType = typeResolvconf } else { log.Infof("detected resolvconf type: %v", implType) } @@ -85,24 +85,14 @@ func (r *resolvconf) supportCustomPort() bool { } func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - if !config.RouteAll { - err = r.restoreHostDNS() - if err != nil { - log.Errorf("restore host dns: %s", err) - } - return ErrRouteAllWithoutNameserverGroup - } - searchDomainList := searchDomains(config) searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains) - options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts) - buf := prepareResolvConfContent( searchDomainList, - append([]string{config.ServerIP}, r.originalNameServers...), - options) + []string{config.ServerIP.String()}, + r.othersConfigs, + ) state := &ShutdownState{ ManagerType: resolvConfManager, @@ -112,8 +102,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman log.Errorf("failed to update shutdown state: %s", err) } - err = r.applyConfig(buf) - if err != nil { + if err := r.applyConfig(buf); err != nil { return fmt.Errorf("apply config: %w", err) } @@ -121,6 +110,10 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman return nil } +func (r *resolvconf) getOriginalNameservers() []string { + return r.originalNameServers +} + func (r *resolvconf) restoreHostDNS() error { var cmd *exec.Cmd @@ -157,7 +150,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { } cmd.Stdin = &content - out, err := cmd.Output() + out, err := cmd.CombinedOutput() log.Tracef("resolvconf output: %s", out) if err != nil { return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index e81aebf98..f933c1de0 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -2,7 +2,6 @@ package dns import ( "context" - "errors" "fmt" "net/netip" "runtime" @@ -20,7 +19,6 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" - cProto "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" ) @@ -41,7 +39,7 @@ type Server interface { DeregisterHandler(domains domain.List, priority int) Initialize() error Stop() - DnsIP() string + DnsIP() netip.Addr UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(strings []string) SearchDomains() []string @@ -53,6 +51,12 @@ type nsGroupsByDomain struct { groups []*nbdns.NameServerGroup } +// hostManagerWithOriginalNS extends the basic hostManager interface +type hostManagerWithOriginalNS interface { + hostManager + getOriginalNameservers() []string +} + // DefaultServer dns server object type DefaultServer struct { ctx context.Context @@ -215,6 +219,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p log.Warn("skipping empty domain") continue } + s.handlerChain.AddHandler(domain, handler, priority) } } @@ -286,7 +291,7 @@ func (s *DefaultServer) Initialize() (err error) { // // When kernel space interface used it return real DNS server listener IP address // For bind interface, fake DNS resolver address returned (second last IP address from Nebird network) -func (s *DefaultServer) DnsIP() string { +func (s *DefaultServer) DnsIP() netip.Addr { return s.service.RuntimeIP() } @@ -297,6 +302,11 @@ func (s *DefaultServer) Stop() { s.ctxCancel() if s.hostManager != nil { + if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 { + log.Debugf("deregistering original nameservers as fallback handlers") + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + } + if err := s.hostManager.restoreHostDNS(); err != nil { log.Error("failed to restore host DNS settings: ", err) } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { @@ -311,7 +321,6 @@ func (s *DefaultServer) Stop() { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone - func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { s.hostsDNSHolder.set(hostsDnsList) @@ -493,25 +502,56 @@ func (s *DefaultServer) applyHostConfig() { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { log.Errorf("failed to apply DNS host manager update: %v", err) - s.handleErrNoGroupaAll(err) } + + s.registerFallback(config) } -func (s *DefaultServer) handleErrNoGroupaAll(err error) { - if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) { +// registerFallback registers original nameservers as low-priority fallback handlers +func (s *DefaultServer) registerFallback(config HostDNSConfig) { + hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) + if !ok { return } - if s.statusRecorder == nil { + originalNameservers := hostMgrWithNS.getOriginalNameservers() + if len(originalNameservers) == 0 { return } - s.statusRecorder.PublishEvent( - cProto.SystemEvent_WARNING, cProto.SystemEvent_DNS, - "The host dns manager does not support match domains", - "The host dns manager does not support match domains without a catch-all nameserver group.", - map[string]string{"manager": s.hostManager.string()}, + log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback) + + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface.Name(), + s.wgInterface.Address().IP, + s.wgInterface.Address().Network, + s.statusRecorder, + s.hostsDNSHolder, + nbdns.RootZone, ) + if err != nil { + log.Errorf("failed to create upstream resolver for original nameservers: %v", err) + return + } + + for _, ns := range originalNameservers { + if ns == config.ServerIP.String() { + log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) + continue + } + + ns = fmt.Sprintf("%s:%d", ns, defaultPort) + if ip, err := netip.ParseAddr(ns); err == nil && ip.Is6() { + ns = fmt.Sprintf("[%s]:%d", ns, defaultPort) + } + + handler.upstreamServers = append(handler.upstreamServers, ns) + } + handler.deactivate = func(error) { /* always active */ } + handler.reactivate = func() { /* always active */ } + + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { @@ -588,14 +628,8 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts priority := basePriority - i - // Check if we're about to overlap with the next priority tier. - // This boundary check ensures that the priority of upstream handlers does not conflict - // with the default priority tier. By decrementing the priority for each handler, we avoid - // overlaps, but if the calculated priority falls into the default tier, we skip the remaining - // handlers to maintain the integrity of the priority system. - if basePriority == PriorityUpstream && priority <= PriorityDefault { - log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityUpstream-PriorityDefault) + // Check if we're about to overlap with the next priority tier + if s.leaksPriority(domainGroup, basePriority, priority) { break } @@ -648,6 +682,21 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai return muxUpdates, nil } +func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { + if basePriority == PriorityUpstream && priority <= PriorityDefault { + log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", + domainGroup.domain, PriorityUpstream-PriorityDefault) + return true + } + if basePriority == PriorityDefault && priority <= PriorityFallback { + log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers", + domainGroup.domain, PriorityDefault-PriorityFallback) + return true + } + + return false +} + func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { // this will introduce a short period of time when the server is not able to handle DNS requests for _, existing := range s.dnsMuxMap { @@ -760,6 +809,12 @@ func (s *DefaultServer) upstreamCallbacks( } func (s *DefaultServer) addHostRootZone() { + hostDNSServers := s.hostsDNSHolder.get() + if len(hostDNSServers) == 0 { + log.Debug("no host DNS servers available, skipping root zone handler creation") + return + } + handler, err := newUpstreamResolver( s.ctx, s.wgInterface.Name(), @@ -775,7 +830,7 @@ func (s *DefaultServer) addHostRootZone() { } handler.upstreamServers = make([]string, 0) - for k := range s.hostsDNSHolder.get() { + for k := range hostDNSServers { handler.upstreamServers = append(handler.upstreamServers, k) } handler.deactivate = func(error) {} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21a9e2f2d..3cab4517a 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -938,7 +938,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return wgIface, nil } -func newDnsResolver(ip string, port int) *net.Resolver { +func newDnsResolver(ip netip.Addr, port int) *net.Resolver { return &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -1047,7 +1047,7 @@ type mockService struct{} func (m *mockService) Listen() error { return nil } func (m *mockService) Stop() {} -func (m *mockService) RuntimeIP() string { return "127.0.0.1" } +func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") } func (m *mockService) RuntimePort() int { return 53 } func (m *mockService) RegisterMux(string, dns.Handler) {} func (m *mockService) DeregisterMux(string) {} diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index 523976e54..ab8238a61 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -1,6 +1,8 @@ package dns import ( + "net/netip" + "github.com/miekg/dns" ) @@ -14,5 +16,5 @@ type service interface { RegisterMux(domain string, handler dns.Handler) DeregisterMux(key string) RuntimePort() int - RuntimeIP() string + RuntimeIP() netip.Addr } diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index 72dc4bc6e..abd2f4f05 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -18,8 +18,11 @@ import ( const ( customPort = 5053 - defaultIP = "127.0.0.1" - customIP = "127.0.0.153" +) + +var ( + defaultIP = netip.MustParseAddr("127.0.0.1") + customIP = netip.MustParseAddr("127.0.0.153") ) type serviceViaListener struct { @@ -27,7 +30,7 @@ type serviceViaListener struct { dnsMux *dns.ServeMux customAddr *netip.AddrPort server *dns.Server - listenIP string + listenIP netip.Addr listenPort uint16 listenerIsRunning bool listenerFlagLock sync.Mutex @@ -65,6 +68,7 @@ func (s *serviceViaListener) Listen() error { log.Errorf("failed to eval runtime address: %s", err) return fmt.Errorf("eval listen address: %w", err) } + s.listenIP = s.listenIP.Unmap() s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) log.Debugf("starting dns on %s", s.server.Addr) go func() { @@ -124,7 +128,7 @@ func (s *serviceViaListener) RuntimePort() int { } } -func (s *serviceViaListener) RuntimeIP() string { +func (s *serviceViaListener) RuntimeIP() netip.Addr { return s.listenIP } @@ -139,9 +143,9 @@ func (s *serviceViaListener) setListenerStatus(running bool) { // first check the 53 port availability on WG interface or lo, if not success // pick a random port on WG interface for eBPF, if not success // check the 5053 port availability on WG interface or lo without eBPF usage, -func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { +func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { if s.customAddr != nil { - return s.customAddr.Addr().String(), s.customAddr.Port(), nil + return s.customAddr.Addr(), s.customAddr.Port(), nil } ip, ok := s.testFreePort(defaultPort) @@ -152,7 +156,7 @@ func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { ebpfSrv, port, ok := s.tryToUseeBPF() if ok { s.ebpfService = ebpfSrv - return s.wgInterface.Address().IP.String(), port, nil + return s.wgInterface.Address().IP, port, nil } ip, ok = s.testFreePort(customPort) @@ -160,15 +164,15 @@ func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { return ip, customPort, nil } - return "", 0, fmt.Errorf("failed to find a free port for DNS server") + return netip.Addr{}, 0, fmt.Errorf("failed to find a free port for DNS server") } -func (s *serviceViaListener) testFreePort(port int) (string, bool) { - var ips []string +func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) { + var ips []netip.Addr if runtime.GOOS != "darwin" { - ips = []string{s.wgInterface.Address().IP.String(), defaultIP, customIP} + ips = []netip.Addr{s.wgInterface.Address().IP, defaultIP, customIP} } else { - ips = []string{defaultIP, customIP} + ips = []netip.Addr{defaultIP, customIP} } for _, ip := range ips { @@ -178,10 +182,10 @@ func (s *serviceViaListener) testFreePort(port int) (string, bool) { return ip, true } - return "", false + return netip.Addr{}, false } -func (s *serviceViaListener) tryToBind(ip string, port int) bool { +func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool { addrString := fmt.Sprintf("%s:%d", ip, port) udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) probeListener, err := net.ListenUDP("udp", udpAddr) @@ -224,7 +228,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) { } func (s *serviceViaListener) generateFreePort() (uint16, error) { - ok := s.tryToBind(s.wgInterface.Address().IP.String(), customPort) + ok := s.tryToBind(s.wgInterface.Address().IP, customPort) if ok { return customPort, nil } diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 226202cf7..9f55838bf 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -16,7 +16,7 @@ import ( type ServiceViaMemory struct { wgInterface WGIface dnsMux *dns.ServeMux - runtimeIP string + runtimeIP netip.Addr runtimePort int udpFilterHookID string listenerIsRunning bool @@ -32,7 +32,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: lastIP.String(), + runtimeIP: lastIP, runtimePort: defaultPort, } return s @@ -84,7 +84,7 @@ func (s *ServiceViaMemory) RuntimePort() int { return s.runtimePort } -func (s *ServiceViaMemory) RuntimeIP() string { +func (s *ServiceViaMemory) RuntimeIP() netip.Addr { return s.runtimeIP } @@ -121,10 +121,5 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { return true } - ip, err := netip.ParseAddr(s.runtimeIP) - if err != nil { - return "", fmt.Errorf("parse runtime ip: %w", err) - } - - return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil + return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 9040ed787..a58747d5b 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -89,21 +89,16 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { } func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - parsedIP, err := netip.ParseAddr(config.ServerIP) - if err != nil { - return fmt.Errorf("unable to parse ip address, error: %w", err) - } - ipAs4 := parsedIP.As4() defaultLinkInput := systemdDbusDNSInput{ Family: unix.AF_INET, - Address: ipAs4[:], + Address: config.ServerIP.AsSlice(), } - if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { + if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err) } // We don't support dnssec. On some machines this is default on so we explicitly set it to off - if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil { + if err := s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil { log.Warnf("failed to set DNSSEC to 'no': %v", err) } @@ -129,8 +124,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana } if config.RouteAll { - err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) - if err != nil { + if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true); err != nil { return fmt.Errorf("set link as default dns router: %w", err) } domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ @@ -139,7 +133,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana }) log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } else { - if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil { + if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil { return fmt.Errorf("remove link as default dns router: %w", err) } } @@ -153,9 +147,8 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) - err = s.setDomainsForInterface(domainsInput) - if err != nil { - log.Error(err) + if err := s.setDomainsForInterface(domainsInput); err != nil { + log.Error("failed to set domains for interface: ", err) } if err := s.flushDNSCache(); err != nil { diff --git a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index fcf60c694..2e786f484 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -35,12 +35,7 @@ func (s *ShutdownState) Cleanup() error { } // TODO: move file contents to state manager -func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { - dnsAddress, err := netip.ParseAddr(dnsAddressStr) - if err != nil { - return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) - } - +func createUncleanShutdownIndicator(sourcePath string, dnsAddress netip.Addr, stateManager *statemanager.Manager) error { dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return fmt.Errorf("create dir %s: %w", dir, err) diff --git a/client/internal/engine.go b/client/internal/engine.go index 079adf7e8..d2de5b3cc 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1550,7 +1550,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { func (e *Engine) wgInterfaceCreate() (err error) { switch runtime.GOOS { case "android": - err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP(), e.dnsServer.SearchDomains()) + err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP().String(), e.dnsServer.SearchDomains()) case "ios": e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr) err = e.wgInterface.Create()