diff --git a/client/android/client.go b/client/android/client.go index 6924d333c..c05246569 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,6 +4,7 @@ package android import ( "context" + "slices" "sync" log "github.com/sirupsen/logrus" @@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) + return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) } // Stop the internal client and free the resources @@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error { return err } - dnsServer.OnUpdatedHostDNSServer(list.items) + dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items)) return nil } diff --git a/client/android/dns_list.go b/client/android/dns_list.go index 76b922220..4c3dff4cc 100644 --- a/client/android/dns_list.go +++ b/client/android/dns_list.go @@ -1,23 +1,34 @@ package android -import "fmt" +import ( + "fmt" + "net/netip" -// DNSList is a wrapper of []string + "github.com/netbirdio/netbird/client/internal/dns" +) + +// DNSList is a wrapper of []netip.AddrPort with default DNS port type DNSList struct { - items []string + items []netip.AddrPort } -// Add new DNS address to the collection -func (array *DNSList) Add(s string) { - array.items = append(array.items, s) +// Add new DNS address to the collection, returns error if invalid +func (array *DNSList) Add(s string) error { + addr, err := netip.ParseAddr(s) + if err != nil { + return fmt.Errorf("invalid DNS address: %s", s) + } + addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort) + array.items = append(array.items, addrPort) + return nil } -// Get return an element of the collection +// Get return an element of the collection as string func (array *DNSList) Get(i int) (string, error) { if i >= len(array.items) || i < 0 { return "", fmt.Errorf("out of range") } - return array.items[i], nil + return array.items[i].Addr().String(), nil } // Size return with the size of the collection diff --git a/client/android/dns_list_test.go b/client/android/dns_list_test.go index 93aea78a8..7cb7b33a1 100644 --- a/client/android/dns_list_test.go +++ b/client/android/dns_list_test.go @@ -3,20 +3,30 @@ package android import "testing" func TestDNSList_Get(t *testing.T) { - l := DNSList{ - items: make([]string, 1), + l := DNSList{} + + // Add a valid DNS address + err := l.Add("8.8.8.8") + if err != nil { + t.Errorf("unexpected error: %s", err) } - _, err := l.Get(0) + // Test getting valid index + addr, err := l.Get(0) if err != nil { t.Errorf("invalid error: %s", err) } + if addr != "8.8.8.8" { + t.Errorf("expected 8.8.8.8, got %s", addr) + } + // Test negative index _, err = l.Get(-1) if err == nil { t.Errorf("expected error but got nil") } + // Test out of bounds index _, err = l.Get(1) if err == nil { t.Errorf("expected error but got nil") diff --git a/client/internal/connect.go b/client/internal/connect.go index d6be285f9..523dcaf1f 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/netip" "runtime" "runtime/debug" "strings" @@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid( tunAdapter device.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, - dnsAddresses []string, + dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, ) error { // in case of non Android os these variables will be nil diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 6e123c94e..8dacb4e51 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -16,7 +16,7 @@ const ( ) type resolvConf struct { - nameServers []string + nameServers []netip.Addr searchDomains []string others []string } @@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) { func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { rconf := &resolvConf{ searchDomains: make([]string, 0), - nameServers: make([]string, 0), + nameServers: make([]netip.Addr, 0), others: make([]string, 0), } @@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { if len(splitLines) != 2 { continue } - rconf.nameServers = append(rconf.nameServers, splitLines[1]) + if addr, err := netip.ParseAddr(splitLines[1]); err == nil { + rconf.nameServers = append(rconf.nameServers, addr.Unmap()) + } else { + log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1]) + } continue } @@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { } return rconf, nil } - -// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position -// and writes the file back to the original location -func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error { - resolvConf, err := parseResolvConfFile(filename) - if err != nil { - return fmt.Errorf("parse backup resolv.conf: %w", err) - } - content, err := os.ReadFile(filename) - if err != nil { - return fmt.Errorf("read %s: %w", filename, err) - } - - if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() { - newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1) - - stat, err := os.Stat(filename) - if err != nil { - return fmt.Errorf("stat %s: %w", filename, err) - } - if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil { - return fmt.Errorf("write %s: %w", filename, err) - } - - } - - return nil -} diff --git a/client/internal/dns/file_parser_unix_test.go b/client/internal/dns/file_parser_unix_test.go index 228a708f1..31e0dd5a0 100644 --- a/client/internal/dns/file_parser_unix_test.go +++ b/client/internal/dns/file_parser_unix_test.go @@ -3,13 +3,9 @@ package dns import ( - "net/netip" "os" "path/filepath" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func Test_parseResolvConf(t *testing.T) { @@ -99,9 +95,13 @@ options debug t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains) } - ok = compareLists(cfg.nameServers, testCase.expectedNS) + nsStrings := make([]string, len(cfg.nameServers)) + for i, ns := range cfg.nameServers { + nsStrings[i] = ns.String() + } + ok = compareLists(nsStrings, testCase.expectedNS) if !ok { - t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers) + t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings) } ok = compareLists(cfg.others, testCase.expectedOther) @@ -177,86 +177,3 @@ nameserver 192.168.0.1 } } -func TestRemoveFirstNbNameserver(t *testing.T) { - testCases := []struct { - name string - content string - ipToRemove string - expected string - }{ - { - name: "Unrelated nameservers with comments and options", - content: `# This is a comment -options rotate -nameserver 1.1.1.1 -# Another comment -nameserver 8.8.4.4 -search example.com`, - ipToRemove: "9.9.9.9", - expected: `# This is a comment -options rotate -nameserver 1.1.1.1 -# Another comment -nameserver 8.8.4.4 -search example.com`, - }, - { - name: "First nameserver matches", - content: `search example.com -nameserver 9.9.9.9 -# oof, a comment -nameserver 8.8.4.4 -options attempts:5`, - ipToRemove: "9.9.9.9", - expected: `search example.com -# oof, a comment -nameserver 8.8.4.4 -options attempts:5`, - }, - { - name: "Target IP not the first nameserver", - // nolint:dupword - content: `# Comment about the first nameserver -nameserver 8.8.4.4 -# Comment before our target -nameserver 9.9.9.9 -options timeout:2`, - ipToRemove: "9.9.9.9", - // nolint:dupword - expected: `# Comment about the first nameserver -nameserver 8.8.4.4 -# Comment before our target -nameserver 9.9.9.9 -options timeout:2`, - }, - { - name: "Only nameserver matches", - content: `options debug -nameserver 9.9.9.9 -search localdomain`, - ipToRemove: "9.9.9.9", - expected: `options debug -nameserver 9.9.9.9 -search localdomain`, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tempDir := t.TempDir() - tempFile := filepath.Join(tempDir, "resolv.conf") - err := os.WriteFile(tempFile, []byte(tc.content), 0644) - assert.NoError(t, err) - - ip, err := netip.ParseAddr(tc.ipToRemove) - require.NoError(t, err, "Failed to parse IP address") - err = removeFirstNbNameserver(tempFile, ip) - assert.NoError(t, err) - - content, err := os.ReadFile(tempFile) - assert.NoError(t, err) - - assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.") - }) - } -} diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index 75af411df..0846dbf38 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon return true } - if rConf.nameServers[0] != nbNameserverIP.String() { + if rConf.nameServers[0] != nbNameserverIP { return true } diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 423989f72..45e621443 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -29,7 +29,7 @@ type fileConfigurator struct { repair *repair originalPerms os.FileMode nbNameserverIP netip.Addr - originalNameservers []string + originalNameservers []netip.Addr } func newFileConfigurator() (*fileConfigurator, error) { @@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st } // getOriginalNameservers returns the nameservers that were found in the original resolv.conf -func (f *fileConfigurator) getOriginalNameservers() []string { +func (f *fileConfigurator) getOriginalNameservers() []netip.Addr { return f.originalNameservers } @@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error { } func (f *fileConfigurator) restore() error { - err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) - if err != nil { - log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err) - } - - err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) - if err != nil { + if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } return os.RemoveAll(fileDefaultResolvConfBackupLocation) } -func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { +func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error { resolvConf, err := parseDefaultResolvConf() if err != nil { return fmt.Errorf("parse current resolv.conf: %w", err) @@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add return restoreResolvConfFile() } - currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0]) - // not a valid first nameserver -> restore - if err != nil { - log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err) - return restoreResolvConfFile() - } - // current address is still netbird's non-available dns address -> restore - // comparing parsed addresses only, to remove ambiguity - if currentDNSAddress.String() == storedDNSAddress.String() { + currentDNSAddress := resolvConf.nameServers[0] + if currentDNSAddress == storedDNSAddress { return restoreResolvConfFile() } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 820cf9029..852dfef48 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { - dnsSettings.ServerIP = ip + dnsSettings.ServerIP = ip.Unmap() inServerAddressesArray = false // Stop reading after finding the first IPv4 address } } @@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } // default to 53 port - dnsSettings.ServerPort = defaultPort + dnsSettings.ServerPort = DefaultPort return dnsSettings, nil } diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 297d50822..422fed4e5 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -42,7 +42,7 @@ func (t osManagerType) String() string { type restoreHostManager interface { hostManager - restoreUncleanShutdownDNS(*netip.Addr) error + restoreUncleanShutdownDNS(netip.Addr) error } func newHostManager(wgInterface string) (hostManager, error) { @@ -130,8 +130,9 @@ func checkStub() bool { return true } + systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53 for _, ns := range rConf.nameServers { - if ns == "127.0.0.53" { + if ns == systemdResolvedAddr { return true } } diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 648a58207..829d83a04 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -216,7 +216,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { return fmt.Errorf("adding dns setup for all failed: %w", err) } r.routingAll = true - log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) + log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort) return nil } diff --git a/client/internal/dns/hosts_dns_holder.go b/client/internal/dns/hosts_dns_holder.go index 2601af9c8..980d917a7 100644 --- a/client/internal/dns/hosts_dns_holder.go +++ b/client/internal/dns/hosts_dns_holder.go @@ -1,38 +1,31 @@ package dns import ( - "fmt" "net/netip" "sync" - - log "github.com/sirupsen/logrus" ) type hostsDNSHolder struct { - unprotectedDNSList map[string]struct{} + unprotectedDNSList map[netip.AddrPort]struct{} mutex sync.RWMutex } func newHostsDNSHolder() *hostsDNSHolder { return &hostsDNSHolder{ - unprotectedDNSList: make(map[string]struct{}), + unprotectedDNSList: make(map[netip.AddrPort]struct{}), } } -func (h *hostsDNSHolder) set(list []string) { +func (h *hostsDNSHolder) set(list []netip.AddrPort) { h.mutex.Lock() - h.unprotectedDNSList = make(map[string]struct{}) - for _, dns := range list { - dnsAddr, err := h.normalizeAddress(dns) - if err != nil { - continue - } - h.unprotectedDNSList[dnsAddr] = struct{}{} + h.unprotectedDNSList = make(map[netip.AddrPort]struct{}) + for _, addrPort := range list { + h.unprotectedDNSList[addrPort] = struct{}{} } h.mutex.Unlock() } -func (h *hostsDNSHolder) get() map[string]struct{} { +func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} { h.mutex.RLock() l := h.unprotectedDNSList h.mutex.RUnlock() @@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} { } //nolint:unused -func (h *hostsDNSHolder) isContain(upstream string) bool { +func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool { h.mutex.RLock() defer h.mutex.RUnlock() _, ok := h.unprotectedDNSList[upstream] return ok } - -func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) { - a, err := netip.ParseAddr(addr) - if err != nil { - log.Errorf("invalid upstream IP address: %s, error: %s", addr, err) - return "", err - } - - if a.Is4() { - return fmt.Sprintf("%s:53", addr), nil - } else { - return fmt.Sprintf("[%s]:53", addr), nil - } -} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 9e0d06399..d160fa99a 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr { return netip.MustParseAddr("100.10.254.255") } -func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { +func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) { // TODO implement me panic("implement me") } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 5459bc2d7..e4ccc8cbd 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error { return nil } -func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error { if err := n.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via network-manager: %w", err) } diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 6080c1d2c..8cdea562b 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -40,7 +40,7 @@ type resolvconf struct { implType resolvconfType originalSearchDomains []string - originalNameServers []string + originalNameServers []netip.Addr othersConfigs []string } @@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman return nil } -func (r *resolvconf) getOriginalNameservers() []string { +func (r *resolvconf) getOriginalNameservers() []netip.Addr { return r.originalNameServers } @@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { return nil } -func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error { +func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error { if err := r.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err) } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index cd0757dc1..ce75369c8 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -42,7 +42,7 @@ type Server interface { Stop() DnsIP() netip.Addr UpdateDNSServer(serial uint64, update nbdns.Config) error - OnUpdatedHostDNSServer(strings []string) + OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string ProbeAvailability() } @@ -55,7 +55,7 @@ type nsGroupsByDomain struct { // hostManagerWithOriginalNS extends the basic hostManager interface type hostManagerWithOriginalNS interface { hostManager - getOriginalNameservers() []string + getOriginalNameservers() []netip.Addr } // DefaultServer dns server object @@ -136,7 +136,7 @@ func NewDefaultServer( func NewDefaultServerPermanentUpstream( ctx context.Context, wgInterface WGIface, - hostsDnsList []string, + hostsDnsList []netip.AddrPort, config nbdns.Config, listener listener.NetworkChangeListener, statusRecorder *peer.Status, @@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream( ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) + ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone -func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { +func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) { s.hostsDNSHolder.set(hostsDnsList) // Check if there's any root handler @@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) - if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { + if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() { log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") s.currentConfig.RouteAll = false @@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { } for _, ns := range originalNameservers { - if ns == config.ServerIP.String() { + if ns == config.ServerIP { log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) continue } - ns = formatAddr(ns, defaultPort) - - handler.upstreamServers = append(handler.upstreamServers, ns) + addrPort := netip.AddrPortFrom(ns, DefaultPort) + handler.upstreamServers = append(handler.upstreamServers, addrPort) } handler.deactivate = func(error) { /* always active */ } handler.reactivate = func() { /* always active */ } @@ -695,7 +695,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) continue } - handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns)) + handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) } if len(handler.upstreamServers) == 0 { @@ -770,18 +770,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { s.dnsMuxMap = muxUpdateMap } -func getNSHostPort(ns nbdns.NameServer) string { - return formatAddr(ns.IP.String(), ns.Port) -} - -// formatAddr formats a nameserver address with port, handling IPv6 addresses properly -func formatAddr(address string, port int) string { - if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() { - return fmt.Sprintf("[%s]:%d", address, port) - } - return fmt.Sprintf("%s:%d", address, port) -} - // upstreamCallbacks returns two functions, the first one is used to deactivate // the upstream resolver from the configuration, the second one is used to // reactivate it. Not allowed to call reactivate before deactivate. @@ -879,10 +867,7 @@ func (s *DefaultServer) addHostRootZone() { return } - handler.upstreamServers = make([]string, 0) - for k := range hostDNSServers { - handler.upstreamServers = append(handler.upstreamServers, k) - } + handler.upstreamServers = maps.Keys(hostDNSServers) handler.deactivate = func(error) {} handler.reactivate = func() {} @@ -893,9 +878,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { var states []peer.NSGroupState for _, group := range groups { - var servers []string + var servers []netip.AddrPort for _, ns := range group.NameServers { - servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) + servers = append(servers, ns.AddrPort()) } state := peer.NSGroupState{ @@ -927,7 +912,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { var servers []string for _, ns := range nsGroup.NameServers { - servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) + servers = append(servers, ns.AddrPort().String()) } return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 3ac1b6eb1..91543da8f 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -97,9 +97,9 @@ func init() { } func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { - var srvs []string + var srvs []netip.AddrPort for _, srv := range servers { - srvs = append(srvs, getNSHostPort(srv)) + srvs = append(srvs, srv.AddrPort()) } return &upstreamResolverBase{ domain: domain, @@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { } defer wgIFace.Close() - var dnsList []string + var dnsList []netip.AddrPort dnsConfig := nbdns.Config{} dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false) err = dnsServer.Initialize() @@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { } defer dnsServer.Stop() - dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"}) + addrPort := netip.MustParseAddrPort("8.8.8.8:53") + dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort}) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort()) _, err = resolver.LookupHost(context.Background(), "netbird.io") @@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) + addrPort := netip.MustParseAddrPort("8.8.8.8:53") + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) + addrPort := netip.MustParseAddrPort("8.8.8.8:53") + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } - -func TestFormatAddr(t *testing.T) { - tests := []struct { - name string - address string - port int - expected string - }{ - { - name: "IPv4 address", - address: "8.8.8.8", - port: 53, - expected: "8.8.8.8:53", - }, - { - name: "IPv4 address with custom port", - address: "1.1.1.1", - port: 5353, - expected: "1.1.1.1:5353", - }, - { - name: "IPv6 address", - address: "fd78:94bf:7df8::1", - port: 53, - expected: "[fd78:94bf:7df8::1]:53", - }, - { - name: "IPv6 address with custom port", - address: "2001:db8::1", - port: 5353, - expected: "[2001:db8::1]:5353", - }, - { - name: "IPv6 localhost", - address: "::1", - port: 53, - expected: "[::1]:53", - }, - { - name: "Invalid address treated as hostname", - address: "dns.example.com", - port: 53, - expected: "dns.example.com:53", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatAddr(tt.address, tt.port) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/client/internal/dns/service.go b/client/internal/dns/service.go index ab8238a61..6a76c53e3 100644 --- a/client/internal/dns/service.go +++ b/client/internal/dns/service.go @@ -7,7 +7,7 @@ import ( ) const ( - defaultPort = 53 + DefaultPort = 53 ) type service interface { diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index abd2f4f05..806559444 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int { defer s.listenerFlagLock.Unlock() if s.ebpfService != nil { - return defaultPort + return DefaultPort } else { return int(s.listenPort) } @@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) { return s.customAddr.Addr(), s.customAddr.Port(), nil } - ip, ok := s.testFreePort(defaultPort) + ip, ok := s.testFreePort(DefaultPort) if ok { - return ip, defaultPort, nil + return ip, DefaultPort, nil } ebpfSrv, port, ok := s.tryToUseeBPF() diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 9f55838bf..89d637686 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { dnsMux: dns.NewServeMux(), runtimeIP: lastIP, - runtimePort: defaultPort, + runtimePort: DefaultPort, } return s } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index a58747d5b..0e8a53a63 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error return nil } -func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via systemd: %w", err) } diff --git a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index 2e786f484..dc44aefaf 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create previous host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { + if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c44d36599..f5d0e775f 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net" + "net/netip" "slices" "strings" "sync" @@ -48,7 +49,7 @@ type upstreamResolverBase struct { ctx context.Context cancel context.CancelFunc upstreamClient upstreamClient - upstreamServers []string + upstreamServers []netip.AddrPort domain string disabled bool failsCount atomic.Int32 @@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("upstream %v", u.upstreamServers) + return fmt.Sprintf("upstream %s", u.upstreamServers) } // ID returns the unique handler ID func (u *upstreamResolverBase) ID() types.HandlerID { servers := slices.Clone(u.upstreamServers) - slices.Sort(servers) + slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) }) hash := sha256.New() hash.Write([]byte(u.domain + ":")) - hash.Write([]byte(strings.Join(servers, ","))) + for _, s := range servers { + hash.Write([]byte(s.String())) + hash.Write([]byte("|")) + } return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) } @@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func() { ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) defer cancel() - rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) }() if err != nil { @@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) { proto.SystemEvent_DNS, "All upstream servers failed (fail count exceeded)", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, + map[string]string{"upstreams": u.upstreamServersString()}, // TODO add domain meta ) } @@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() { proto.SystemEvent_DNS, "All upstream servers failed (probe failed)", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, + map[string]string{"upstreams": u.upstreamServersString()}, ) } } @@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { operation := func() error { select { case <-u.ctx.Done(): - return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers)) + return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString())) default: } @@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } } - log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) + log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff()) return fmt.Errorf("upstream check call error") } @@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { return } - log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) + log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) u.failsCount.Store(0) u.successCount.Add(1) u.reactivate() @@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) { go u.waitUntilResponse() } -func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { +func (u *upstreamResolverBase) upstreamServersString() string { + var servers []string + for _, server := range u.upstreamServers { + servers = append(servers, server.String()) + } + return strings.Join(servers, ", ") +} + +func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error { ctx, cancel := context.WithTimeout(u.ctx, timeout) defer cancel() r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - _, _, err := u.upstreamClient.exchange(ctx, server, r) + _, _, err := u.upstreamClient.exchange(ctx, server.String(), r) return err } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index e7db581b1..ddbf84ae4 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri } func (u *upstreamResolver) isLocalResolver(upstream string) bool { - if u.hostsDNSHolder.isContain(upstream) { - return true + if addrPort, err := netip.ParseAddrPort(upstream); err == nil { + return u.hostsDNSHolder.contains(addrPort) } return false } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 648cab176..96b8bbb0f 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * upstreamIP, err := netip.ParseAddr(upstreamHost) if err != nil { log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err) + } else { + upstreamIP = upstreamIP.Unmap() } if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { log.Debugf("using private client to query upstream: %s", upstream) diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index e440995d9..51d870e2a 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") - resolver.upstreamServers = testCase.InputServers + // Convert test servers to netip.AddrPort + var servers []netip.AddrPort + for _, server := range testCase.InputServers { + if addrPort, err := netip.ParseAddrPort(server); err == nil { + servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())) + } + } + resolver.upstreamServers = servers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { cancel() @@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, } - resolver.upstreamServers = []string{"0.0.0.0:-1"} + addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection + resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} resolver.failsTillDeact = 0 resolver.reactivatePeriod = time.Microsecond * 100 diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 4ac0fc141..7c95e2b99 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -1,6 +1,8 @@ package internal import ( + "net/netip" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -13,7 +15,7 @@ type MobileDependency struct { TunAdapter device.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover NetworkChangeListener listener.NetworkChangeListener - HostDNSAddresses []string + HostDNSAddresses []netip.AddrPort DnsReadyListener dns.ReadyListener // iOS only diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index bd4e63357..239cce7e0 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -140,7 +140,7 @@ type RosenpassState struct { // whether it's enabled, and the last error message encountered during probing. type NSGroupState struct { ID string - Servers []string + Servers []netip.AddrPort Domains []string Enabled bool Error error diff --git a/client/server/server.go b/client/server/server.go index fadf2e7fc..daef7d02b 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { if dnsState.Error != nil { err = dnsState.Error.Error() } + + var servers []string + for _, server := range dnsState.Servers { + servers = append(servers, server.String()) + } + pbDnsState := &proto.NSGroupState{ - Servers: dnsState.Servers, + Servers: servers, Domains: dnsState.Domains, Enabled: dnsState.Enabled, Error: err, diff --git a/dns/nameserver.go b/dns/nameserver.go index bb904b165..81c616c50 100644 --- a/dns/nameserver.go +++ b/dns/nameserver.go @@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool { other.Port == n.Port } +// AddrPort returns the nameserver as a netip.AddrPort +func (n *NameServer) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(n.IP, uint16(n.Port)) +} + // ParseNameServerURL parses a nameserver url in the format ://:, e.g., udp://1.1.1.1:53 func ParseNameServerURL(nsURL string) (NameServer, error) { parsedURL, err := url.Parse(nsURL)