[client] Eliminate upstream server strings in dns code (#4267)

This commit is contained in:
Viktor Liu
2025-08-11 11:57:21 +02:00
committed by GitHub
parent 375fcf2752
commit 1022a5015c
30 changed files with 161 additions and 308 deletions

View File

@ -4,6 +4,7 @@ package android
import (
"context"
"slices"
"sync"
log "github.com/sirupsen/logrus"
@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
// Stop the internal client and free the resources
@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
return err
}
dnsServer.OnUpdatedHostDNSServer(list.items)
dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
return nil
}

View File

@ -1,23 +1,34 @@
package android
import "fmt"
import (
"fmt"
"net/netip"
// DNSList is a wrapper of []string
"github.com/netbirdio/netbird/client/internal/dns"
)
// DNSList is a wrapper of []netip.AddrPort with default DNS port
type DNSList struct {
items []string
items []netip.AddrPort
}
// Add new DNS address to the collection
func (array *DNSList) Add(s string) {
array.items = append(array.items, s)
// Add new DNS address to the collection, returns error if invalid
func (array *DNSList) Add(s string) error {
addr, err := netip.ParseAddr(s)
if err != nil {
return fmt.Errorf("invalid DNS address: %s", s)
}
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
array.items = append(array.items, addrPort)
return nil
}
// Get return an element of the collection
// Get return an element of the collection as string
func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range")
}
return array.items[i], nil
return array.items[i].Addr().String(), nil
}
// Size return with the size of the collection

View File

@ -3,20 +3,30 @@ package android
import "testing"
func TestDNSList_Get(t *testing.T) {
l := DNSList{
items: make([]string, 1),
l := DNSList{}
// Add a valid DNS address
err := l.Add("8.8.8.8")
if err != nil {
t.Errorf("unexpected error: %s", err)
}
_, err := l.Get(0)
// Test getting valid index
addr, err := l.Get(0)
if err != nil {
t.Errorf("invalid error: %s", err)
}
if addr != "8.8.8.8" {
t.Errorf("expected 8.8.8.8, got %s", addr)
}
// Test negative index
_, err = l.Get(-1)
if err == nil {
t.Errorf("expected error but got nil")
}
// Test out of bounds index
_, err = l.Get(1)
if err == nil {
t.Errorf("expected error but got nil")

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"runtime/debug"
"strings"
@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
) error {
// in case of non Android os these variables will be nil

View File

@ -16,7 +16,7 @@ const (
)
type resolvConf struct {
nameServers []string
nameServers []netip.Addr
searchDomains []string
others []string
}
@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
rconf := &resolvConf{
searchDomains: make([]string, 0),
nameServers: make([]string, 0),
nameServers: make([]netip.Addr, 0),
others: make([]string, 0),
}
@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
if len(splitLines) != 2 {
continue
}
rconf.nameServers = append(rconf.nameServers, splitLines[1])
if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
} else {
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
}
continue
}
@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
}
return rconf, nil
}
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@ -3,13 +3,9 @@
package dns
import (
"net/netip"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_parseResolvConf(t *testing.T) {
@ -99,9 +95,13 @@ options debug
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
}
ok = compareLists(cfg.nameServers, testCase.expectedNS)
nsStrings := make([]string, len(cfg.nameServers))
for i, ns := range cfg.nameServers {
nsStrings[i] = ns.String()
}
ok = compareLists(nsStrings, testCase.expectedNS)
if !ok {
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
}
ok = compareLists(cfg.others, testCase.expectedOther)
@ -177,86 +177,3 @@ nameserver 192.168.0.1
}
}
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
ip, err := netip.ParseAddr(tc.ipToRemove)
require.NoError(t, err, "Failed to parse IP address")
err = removeFirstNbNameserver(tempFile, ip)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
return true
}
if rConf.nameServers[0] != nbNameserverIP.String() {
if rConf.nameServers[0] != nbNameserverIP {
return true
}

View File

@ -29,7 +29,7 @@ type fileConfigurator struct {
repair *repair
originalPerms os.FileMode
nbNameserverIP netip.Addr
originalNameservers []string
originalNameservers []netip.Addr
}
func newFileConfigurator() (*fileConfigurator, error) {
@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
}
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf
func (f *fileConfigurator) getOriginalNameservers() []string {
func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
return f.originalNameservers
}
@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
}
func (f *fileConfigurator) restore() error {
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
resolvConf, err := parseDefaultResolvConf()
if err != nil {
return fmt.Errorf("parse current resolv.conf: %w", err)
@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile()
}
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile()
}
// current address is still netbird's non-available dns address -> restore
// comparing parsed addresses only, to remove ambiguity
if currentDNSAddress.String() == storedDNSAddress.String() {
currentDNSAddress := resolvConf.nameServers[0]
if currentDNSAddress == storedDNSAddress {
return restoreResolvConfFile()
}

View File

@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
// default to 53 port
dnsSettings.ServerPort = defaultPort
dnsSettings.ServerPort = DefaultPort
return dnsSettings, nil
}

View File

@ -42,7 +42,7 @@ func (t osManagerType) String() string {
type restoreHostManager interface {
hostManager
restoreUncleanShutdownDNS(*netip.Addr) error
restoreUncleanShutdownDNS(netip.Addr) error
}
func newHostManager(wgInterface string) (hostManager, error) {
@ -130,8 +130,9 @@ func checkStub() bool {
return true
}
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
for _, ns := range rConf.nameServers {
if ns == "127.0.0.53" {
if ns == systemdResolvedAddr {
return true
}
}

View File

@ -216,7 +216,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return fmt.Errorf("adding dns setup for all failed: %w", err)
}
r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
return nil
}

View File

@ -1,38 +1,31 @@
package dns
import (
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
)
type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{}
unprotectedDNSList map[netip.AddrPort]struct{}
mutex sync.RWMutex
}
func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}),
unprotectedDNSList: make(map[netip.AddrPort]struct{}),
}
}
func (h *hostsDNSHolder) set(list []string) {
func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{})
for _, dns := range list {
dnsAddr, err := h.normalizeAddress(dns)
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
for _, addrPort := range list {
h.unprotectedDNSList[addrPort] = struct{}{}
}
h.mutex.Unlock()
}
func (h *hostsDNSHolder) get() map[string]struct{} {
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList
h.mutex.RUnlock()
@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
}
//nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool {
func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
h.mutex.RLock()
defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream]
return ok
}
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
return netip.MustParseAddr("100.10.254.255")
}
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
// TODO implement me
panic("implement me")
}

View File

@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
return nil
}
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := n.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via network-manager: %w", err)
}

View File

@ -40,7 +40,7 @@ type resolvconf struct {
implType resolvconfType
originalSearchDomains []string
originalNameServers []string
originalNameServers []netip.Addr
othersConfigs []string
}
@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
return nil
}
func (r *resolvconf) getOriginalNameservers() []string {
func (r *resolvconf) getOriginalNameservers() []netip.Addr {
return r.originalNameServers
}
@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
return nil
}
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
}

View File

@ -42,7 +42,7 @@ type Server interface {
Stop()
DnsIP() netip.Addr
UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string
ProbeAvailability()
}
@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
// hostManagerWithOriginalNS extends the basic hostManager interface
type hostManagerWithOriginalNS interface {
hostManager
getOriginalNameservers() []string
getOriginalNameservers() []netip.Addr
}
// DefaultServer dns server object
@ -136,7 +136,7 @@ func NewDefaultServer(
func NewDefaultServerPermanentUpstream(
ctx context.Context,
wgInterface WGIface,
hostsDnsList []string,
hostsDnsList []netip.AddrPort,
config nbdns.Config,
listener listener.NetworkChangeListener,
statusRecorder *peer.Status,
@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
s.hostsDNSHolder.set(hostsDnsList)
// Check if there's any root handler
@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
s.currentConfig.RouteAll = false
@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
}
for _, ns := range originalNameservers {
if ns == config.ServerIP.String() {
if ns == config.ServerIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
continue
}
ns = formatAddr(ns, defaultPort)
handler.upstreamServers = append(handler.upstreamServers, ns)
addrPort := netip.AddrPortFrom(ns, DefaultPort)
handler.upstreamServers = append(handler.upstreamServers, addrPort)
}
handler.deactivate = func(error) { /* always active */ }
handler.reactivate = func() { /* always active */ }
@ -695,7 +695,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
}
if len(handler.upstreamServers) == 0 {
@ -770,18 +770,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap
}
func getNSHostPort(ns nbdns.NameServer) string {
return formatAddr(ns.IP.String(), ns.Port)
}
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
func formatAddr(address string, port int) string {
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
return fmt.Sprintf("[%s]:%d", address, port)
}
return fmt.Sprintf("%s:%d", address, port)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate.
@ -879,10 +867,7 @@ func (s *DefaultServer) addHostRootZone() {
return
}
handler.upstreamServers = make([]string, 0)
for k := range hostDNSServers {
handler.upstreamServers = append(handler.upstreamServers, k)
}
handler.upstreamServers = maps.Keys(hostDNSServers)
handler.deactivate = func(error) {}
handler.reactivate = func() {}
@ -893,9 +878,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState
for _, group := range groups {
var servers []string
var servers []netip.AddrPort
for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
servers = append(servers, ns.AddrPort())
}
state := peer.NSGroupState{
@ -927,7 +912,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string
for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
servers = append(servers, ns.AddrPort().String())
}
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
}

View File

@ -97,9 +97,9 @@ func init() {
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv))
srvs = append(srvs, srv.AddrPort())
}
return &upstreamResolverBase{
domain: domain,
@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
defer wgIFace.Close()
var dnsList []string
var dnsList []netip.AddrPort
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize()
@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
func TestFormatAddr(t *testing.T) {
tests := []struct {
name string
address string
port int
expected string
}{
{
name: "IPv4 address",
address: "8.8.8.8",
port: 53,
expected: "8.8.8.8:53",
},
{
name: "IPv4 address with custom port",
address: "1.1.1.1",
port: 5353,
expected: "1.1.1.1:5353",
},
{
name: "IPv6 address",
address: "fd78:94bf:7df8::1",
port: 53,
expected: "[fd78:94bf:7df8::1]:53",
},
{
name: "IPv6 address with custom port",
address: "2001:db8::1",
port: 5353,
expected: "[2001:db8::1]:5353",
},
{
name: "IPv6 localhost",
address: "::1",
port: 53,
expected: "[::1]:53",
},
{
name: "Invalid address treated as hostname",
address: "dns.example.com",
port: 53,
expected: "dns.example.com:53",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatAddr(tt.address, tt.port)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -7,7 +7,7 @@ import (
)
const (
defaultPort = 53
DefaultPort = 53
)
type service interface {

View File

@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil {
return defaultPort
return DefaultPort
} else {
return int(s.listenPort)
}
@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
return s.customAddr.Addr(), s.customAddr.Port(), nil
}
ip, ok := s.testFreePort(defaultPort)
ip, ok := s.testFreePort(DefaultPort)
if ok {
return ip, defaultPort, nil
return ip, DefaultPort, nil
}
ebpfSrv, port, ok := s.tryToUseeBPF()

View File

@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: defaultPort,
runtimePort: DefaultPort,
}
return s
}

View File

@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
return nil
}
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via systemd: %w", err)
}

View File

@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create previous host manager: %w", err)
}
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err)
}

View File

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
@ -48,7 +49,7 @@ type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []string
upstreamServers []netip.AddrPort
domain string
disabled bool
failsCount atomic.Int32
@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers)
return fmt.Sprintf("upstream %s", u.upstreamServers)
}
// ID returns the unique handler ID
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.Sort(servers)
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ",")))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta
)
}
@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
map[string]string{"upstreams": u.upstreamServersString()},
)
}
}
@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
operation := func() error {
select {
case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default:
}
@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
}
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error")
}
@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate()
@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(ctx, server, r)
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err
}

View File

@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) {
return true
if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
return u.hostsDNSHolder.contains(addrPort)
}
return false
}

View File

@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
} else {
upstreamIP = upstreamIP.Unmap()
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream)

View File

@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers
// Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
if addrPort, err := netip.ParseAddrPort(server); err == nil {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
resolver.upstreamServers = []string{"0.0.0.0:-1"}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100

View File

@ -1,6 +1,8 @@
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@ -13,7 +15,7 @@ type MobileDependency struct {
TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string
HostDNSAddresses []netip.AddrPort
DnsReadyListener dns.ReadyListener
// iOS only

View File

@ -140,7 +140,7 @@ type RosenpassState struct {
// whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct {
ID string
Servers []string
Servers []netip.AddrPort
Domains []string
Enabled bool
Error error

View File

@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: dnsState.Servers,
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,

View File

@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
other.Port == n.Port
}
// AddrPort returns the nameserver as a netip.AddrPort
func (n *NameServer) AddrPort() netip.AddrPort {
return netip.AddrPortFrom(n.IP, uint16(n.Port))
}
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
func ParseNameServerURL(nsURL string) (NameServer, error) {
parsedURL, err := url.Parse(nsURL)