mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-17 18:41:41 +02:00
[client] Eliminate upstream server strings in dns code (#4267)
This commit is contained in:
@ -4,6 +4,7 @@ package android
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,23 +1,34 @@
|
|||||||
package android
|
package android
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
// DNSList is a wrapper of []string
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DNSList is a wrapper of []netip.AddrPort with default DNS port
|
||||||
type DNSList struct {
|
type DNSList struct {
|
||||||
items []string
|
items []netip.AddrPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new DNS address to the collection
|
// Add new DNS address to the collection, returns error if invalid
|
||||||
func (array *DNSList) Add(s string) {
|
func (array *DNSList) Add(s string) error {
|
||||||
array.items = append(array.items, s)
|
addr, err := netip.ParseAddr(s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid DNS address: %s", s)
|
||||||
|
}
|
||||||
|
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
|
||||||
|
array.items = append(array.items, addrPort)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get return an element of the collection
|
// Get return an element of the collection as string
|
||||||
func (array *DNSList) Get(i int) (string, error) {
|
func (array *DNSList) Get(i int) (string, error) {
|
||||||
if i >= len(array.items) || i < 0 {
|
if i >= len(array.items) || i < 0 {
|
||||||
return "", fmt.Errorf("out of range")
|
return "", fmt.Errorf("out of range")
|
||||||
}
|
}
|
||||||
return array.items[i], nil
|
return array.items[i].Addr().String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size return with the size of the collection
|
// Size return with the size of the collection
|
||||||
|
@ -3,20 +3,30 @@ package android
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestDNSList_Get(t *testing.T) {
|
func TestDNSList_Get(t *testing.T) {
|
||||||
l := DNSList{
|
l := DNSList{}
|
||||||
items: make([]string, 1),
|
|
||||||
|
// Add a valid DNS address
|
||||||
|
err := l.Add("8.8.8.8")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := l.Get(0)
|
// Test getting valid index
|
||||||
|
addr, err := l.Get(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("invalid error: %s", err)
|
t.Errorf("invalid error: %s", err)
|
||||||
}
|
}
|
||||||
|
if addr != "8.8.8.8" {
|
||||||
|
t.Errorf("expected 8.8.8.8, got %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test negative index
|
||||||
_, err = l.Get(-1)
|
_, err = l.Get(-1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error but got nil")
|
t.Errorf("expected error but got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test out of bounds index
|
||||||
_, err = l.Get(1)
|
_, err = l.Get(1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected error but got nil")
|
t.Errorf("expected error but got nil")
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
@ -70,7 +71,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
tunAdapter device.TunAdapter,
|
tunAdapter device.TunAdapter,
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsAddresses []string,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
|
@ -16,7 +16,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
nameServers []string
|
nameServers []netip.Addr
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
others []string
|
others []string
|
||||||
}
|
}
|
||||||
@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
|
|||||||
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
||||||
rconf := &resolvConf{
|
rconf := &resolvConf{
|
||||||
searchDomains: make([]string, 0),
|
searchDomains: make([]string, 0),
|
||||||
nameServers: make([]string, 0),
|
nameServers: make([]netip.Addr, 0),
|
||||||
others: make([]string, 0),
|
others: make([]string, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
|||||||
if len(splitLines) != 2 {
|
if len(splitLines) != 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rconf.nameServers = append(rconf.nameServers, splitLines[1])
|
if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
|
||||||
|
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
|
||||||
|
} else {
|
||||||
|
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
|
|||||||
}
|
}
|
||||||
return rconf, nil
|
return rconf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
|
|
||||||
// and writes the file back to the original location
|
|
||||||
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
|
|
||||||
resolvConf, err := parseResolvConfFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse backup resolv.conf: %w", err)
|
|
||||||
}
|
|
||||||
content, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read %s: %w", filename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
|
|
||||||
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
|
|
||||||
|
|
||||||
stat, err := os.Stat(filename)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stat %s: %w", filename, err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
|
|
||||||
return fmt.Errorf("write %s: %w", filename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -3,13 +3,9 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_parseResolvConf(t *testing.T) {
|
func Test_parseResolvConf(t *testing.T) {
|
||||||
@ -99,9 +95,13 @@ options debug
|
|||||||
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok = compareLists(cfg.nameServers, testCase.expectedNS)
|
nsStrings := make([]string, len(cfg.nameServers))
|
||||||
|
for i, ns := range cfg.nameServers {
|
||||||
|
nsStrings[i] = ns.String()
|
||||||
|
}
|
||||||
|
ok = compareLists(nsStrings, testCase.expectedNS)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers)
|
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok = compareLists(cfg.others, testCase.expectedOther)
|
ok = compareLists(cfg.others, testCase.expectedOther)
|
||||||
@ -177,86 +177,3 @@ nameserver 192.168.0.1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveFirstNbNameserver(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
content string
|
|
||||||
ipToRemove string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Unrelated nameservers with comments and options",
|
|
||||||
content: `# This is a comment
|
|
||||||
options rotate
|
|
||||||
nameserver 1.1.1.1
|
|
||||||
# Another comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
search example.com`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
expected: `# This is a comment
|
|
||||||
options rotate
|
|
||||||
nameserver 1.1.1.1
|
|
||||||
# Another comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
search example.com`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "First nameserver matches",
|
|
||||||
content: `search example.com
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
# oof, a comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
options attempts:5`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
expected: `search example.com
|
|
||||||
# oof, a comment
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
options attempts:5`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Target IP not the first nameserver",
|
|
||||||
// nolint:dupword
|
|
||||||
content: `# Comment about the first nameserver
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
# Comment before our target
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
options timeout:2`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
// nolint:dupword
|
|
||||||
expected: `# Comment about the first nameserver
|
|
||||||
nameserver 8.8.4.4
|
|
||||||
# Comment before our target
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
options timeout:2`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Only nameserver matches",
|
|
||||||
content: `options debug
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
search localdomain`,
|
|
||||||
ipToRemove: "9.9.9.9",
|
|
||||||
expected: `options debug
|
|
||||||
nameserver 9.9.9.9
|
|
||||||
search localdomain`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
tempFile := filepath.Join(tempDir, "resolv.conf")
|
|
||||||
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(tc.ipToRemove)
|
|
||||||
require.NoError(t, err, "Failed to parse IP address")
|
|
||||||
err = removeFirstNbNameserver(tempFile, ip)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
content, err := os.ReadFile(tempFile)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rConf.nameServers[0] != nbNameserverIP.String() {
|
if rConf.nameServers[0] != nbNameserverIP {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ type fileConfigurator struct {
|
|||||||
repair *repair
|
repair *repair
|
||||||
originalPerms os.FileMode
|
originalPerms os.FileMode
|
||||||
nbNameserverIP netip.Addr
|
nbNameserverIP netip.Addr
|
||||||
originalNameservers []string
|
originalNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFileConfigurator() (*fileConfigurator, error) {
|
func newFileConfigurator() (*fileConfigurator, error) {
|
||||||
@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf
|
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf
|
||||||
func (f *fileConfigurator) getOriginalNameservers() []string {
|
func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
return f.originalNameservers
|
return f.originalNameservers
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) restore() error {
|
func (f *fileConfigurator) restore() error {
|
||||||
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP)
|
if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
|
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
|
||||||
resolvConf, err := parseDefaultResolvConf()
|
resolvConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse current resolv.conf: %w", err)
|
return fmt.Errorf("parse current resolv.conf: %w", err)
|
||||||
@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
|
|||||||
return restoreResolvConfFile()
|
return restoreResolvConfFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
|
|
||||||
// not a valid first nameserver -> restore
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
|
|
||||||
return restoreResolvConfFile()
|
|
||||||
}
|
|
||||||
|
|
||||||
// current address is still netbird's non-available dns address -> restore
|
// current address is still netbird's non-available dns address -> restore
|
||||||
// comparing parsed addresses only, to remove ambiguity
|
currentDNSAddress := resolvConf.nameServers[0]
|
||||||
if currentDNSAddress.String() == storedDNSAddress.String() {
|
if currentDNSAddress == storedDNSAddress {
|
||||||
return restoreResolvConfFile()
|
return restoreResolvConfFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
} else if inServerAddressesArray {
|
} else if inServerAddressesArray {
|
||||||
address := strings.Split(line, " : ")[1]
|
address := strings.Split(line, " : ")[1]
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
||||||
dnsSettings.ServerIP = ip
|
dnsSettings.ServerIP = ip.Unmap()
|
||||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -250,7 +250,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// default to 53 port
|
// default to 53 port
|
||||||
dnsSettings.ServerPort = defaultPort
|
dnsSettings.ServerPort = DefaultPort
|
||||||
|
|
||||||
return dnsSettings, nil
|
return dnsSettings, nil
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func (t osManagerType) String() string {
|
|||||||
|
|
||||||
type restoreHostManager interface {
|
type restoreHostManager interface {
|
||||||
hostManager
|
hostManager
|
||||||
restoreUncleanShutdownDNS(*netip.Addr) error
|
restoreUncleanShutdownDNS(netip.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
@ -130,8 +130,9 @@ func checkStub() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
|
||||||
for _, ns := range rConf.nameServers {
|
for _, ns := range rConf.nameServers {
|
||||||
if ns == "127.0.0.53" {
|
if ns == systemdResolvedAddr {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -216,7 +216,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
|||||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
}
|
}
|
||||||
r.routingAll = true
|
r.routingAll = true
|
||||||
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
|
log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,38 +1,31 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostsDNSHolder struct {
|
type hostsDNSHolder struct {
|
||||||
unprotectedDNSList map[string]struct{}
|
unprotectedDNSList map[netip.AddrPort]struct{}
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostsDNSHolder() *hostsDNSHolder {
|
func newHostsDNSHolder() *hostsDNSHolder {
|
||||||
return &hostsDNSHolder{
|
return &hostsDNSHolder{
|
||||||
unprotectedDNSList: make(map[string]struct{}),
|
unprotectedDNSList: make(map[netip.AddrPort]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hostsDNSHolder) set(list []string) {
|
func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.unprotectedDNSList = make(map[string]struct{})
|
h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
|
||||||
for _, dns := range list {
|
for _, addrPort := range list {
|
||||||
dnsAddr, err := h.normalizeAddress(dns)
|
h.unprotectedDNSList[addrPort] = struct{}{}
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
h.unprotectedDNSList[dnsAddr] = struct{}{}
|
|
||||||
}
|
}
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hostsDNSHolder) get() map[string]struct{} {
|
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
l := h.unprotectedDNSList
|
l := h.unprotectedDNSList
|
||||||
h.mutex.RUnlock()
|
h.mutex.RUnlock()
|
||||||
@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//nolint:unused
|
//nolint:unused
|
||||||
func (h *hostsDNSHolder) isContain(upstream string) bool {
|
func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
|
||||||
h.mutex.RLock()
|
h.mutex.RLock()
|
||||||
defer h.mutex.RUnlock()
|
defer h.mutex.RUnlock()
|
||||||
|
|
||||||
_, ok := h.unprotectedDNSList[upstream]
|
_, ok := h.unprotectedDNSList[upstream]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
|
|
||||||
a, err := netip.ParseAddr(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.Is4() {
|
|
||||||
return fmt.Sprintf("%s:53", addr), nil
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("[%s]:53", addr), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
|
|||||||
return netip.MustParseAddr("100.10.254.255")
|
return netip.MustParseAddr("100.10.254.255")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
|
||||||
// TODO implement me
|
// TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||||
if err := n.restoreHostDNS(); err != nil {
|
if err := n.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
return fmt.Errorf("restoring dns via network-manager: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ type resolvconf struct {
|
|||||||
implType resolvconfType
|
implType resolvconfType
|
||||||
|
|
||||||
originalSearchDomains []string
|
originalSearchDomains []string
|
||||||
originalNameServers []string
|
originalNameServers []netip.Addr
|
||||||
othersConfigs []string
|
othersConfigs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) getOriginalNameservers() []string {
|
func (r *resolvconf) getOriginalNameservers() []netip.Addr {
|
||||||
return r.originalNameServers
|
return r.originalNameServers
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||||
if err := r.restoreHostDNS(); err != nil {
|
if err := r.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ type Server interface {
|
|||||||
Stop()
|
Stop()
|
||||||
DnsIP() netip.Addr
|
DnsIP() netip.Addr
|
||||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||||
OnUpdatedHostDNSServer(strings []string)
|
OnUpdatedHostDNSServer(addrs []netip.AddrPort)
|
||||||
SearchDomains() []string
|
SearchDomains() []string
|
||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
|
|||||||
// hostManagerWithOriginalNS extends the basic hostManager interface
|
// hostManagerWithOriginalNS extends the basic hostManager interface
|
||||||
type hostManagerWithOriginalNS interface {
|
type hostManagerWithOriginalNS interface {
|
||||||
hostManager
|
hostManager
|
||||||
getOriginalNameservers() []string
|
getOriginalNameservers() []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
@ -136,7 +136,7 @@ func NewDefaultServer(
|
|||||||
func NewDefaultServerPermanentUpstream(
|
func NewDefaultServerPermanentUpstream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
wgInterface WGIface,
|
wgInterface WGIface,
|
||||||
hostsDnsList []string,
|
hostsDnsList []netip.AddrPort,
|
||||||
config nbdns.Config,
|
config nbdns.Config,
|
||||||
listener listener.NetworkChangeListener,
|
listener listener.NetworkChangeListener,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
|
|||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||||
|
|
||||||
ds.hostsDNSHolder.set(hostsDnsList)
|
ds.hostsDNSHolder.set(hostsDnsList)
|
||||||
ds.permanent = true
|
ds.permanent = true
|
||||||
ds.addHostRootZone()
|
ds.addHostRootZone()
|
||||||
@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
|
|||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
|
||||||
s.hostsDNSHolder.set(hostsDnsList)
|
s.hostsDNSHolder.set(hostsDnsList)
|
||||||
|
|
||||||
// Check if there's any root handler
|
// Check if there's any root handler
|
||||||
@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, ns := range originalNameservers {
|
for _, ns := range originalNameservers {
|
||||||
if ns == config.ServerIP.String() {
|
if ns == config.ServerIP {
|
||||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ns = formatAddr(ns, defaultPort)
|
addrPort := netip.AddrPortFrom(ns, DefaultPort)
|
||||||
|
handler.upstreamServers = append(handler.upstreamServers, addrPort)
|
||||||
handler.upstreamServers = append(handler.upstreamServers, ns)
|
|
||||||
}
|
}
|
||||||
handler.deactivate = func(error) { /* always active */ }
|
handler.deactivate = func(error) { /* always active */ }
|
||||||
handler.reactivate = func() { /* always active */ }
|
handler.reactivate = func() { /* always active */ }
|
||||||
@ -695,7 +695,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
|
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
if len(handler.upstreamServers) == 0 {
|
||||||
@ -770,18 +770,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
|
||||||
return formatAddr(ns.IP.String(), ns.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
|
|
||||||
func formatAddr(address string, port int) string {
|
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
|
|
||||||
return fmt.Sprintf("[%s]:%d", address, port)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s:%d", address, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||||
// the upstream resolver from the configuration, the second one is used to
|
// the upstream resolver from the configuration, the second one is used to
|
||||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||||
@ -879,10 +867,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
handler.upstreamServers = make([]string, 0)
|
handler.upstreamServers = maps.Keys(hostDNSServers)
|
||||||
for k := range hostDNSServers {
|
|
||||||
handler.upstreamServers = append(handler.upstreamServers, k)
|
|
||||||
}
|
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
handler.reactivate = func() {}
|
handler.reactivate = func() {}
|
||||||
|
|
||||||
@ -893,9 +878,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
|
|||||||
var states []peer.NSGroupState
|
var states []peer.NSGroupState
|
||||||
|
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
var servers []string
|
var servers []netip.AddrPort
|
||||||
for _, ns := range group.NameServers {
|
for _, ns := range group.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, ns.AddrPort())
|
||||||
}
|
}
|
||||||
|
|
||||||
state := peer.NSGroupState{
|
state := peer.NSGroupState{
|
||||||
@ -927,7 +912,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
|
|||||||
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
|
||||||
var servers []string
|
var servers []string
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port))
|
servers = append(servers, ns.AddrPort().String())
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
|
||||||
}
|
}
|
||||||
|
@ -97,9 +97,9 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||||
var srvs []string
|
var srvs []netip.AddrPort
|
||||||
for _, srv := range servers {
|
for _, srv := range servers {
|
||||||
srvs = append(srvs, getNSHostPort(srv))
|
srvs = append(srvs, srv.AddrPort())
|
||||||
}
|
}
|
||||||
return &upstreamResolverBase{
|
return &upstreamResolverBase{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
|
|
||||||
var dnsList []string
|
var dnsList []netip.AddrPort
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
|
|
||||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
|
||||||
|
|
||||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||||
@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer wgIFace.Close()
|
defer wgIFace.Close()
|
||||||
dnsConfig := nbdns.Config{}
|
dnsConfig := nbdns.Config{}
|
||||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
addrPort := netip.MustParseAddrPort("8.8.8.8:53")
|
||||||
|
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
|
||||||
err = dnsServer.Initialize()
|
err = dnsServer.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to initialize DNS server: %v", err)
|
t.Errorf("failed to initialize DNS server: %v", err)
|
||||||
@ -2053,56 +2056,3 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
|||||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatAddr(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
address string
|
|
||||||
port int
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv4 address",
|
|
||||||
address: "8.8.8.8",
|
|
||||||
port: 53,
|
|
||||||
expected: "8.8.8.8:53",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv4 address with custom port",
|
|
||||||
address: "1.1.1.1",
|
|
||||||
port: 5353,
|
|
||||||
expected: "1.1.1.1:5353",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 address",
|
|
||||||
address: "fd78:94bf:7df8::1",
|
|
||||||
port: 53,
|
|
||||||
expected: "[fd78:94bf:7df8::1]:53",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 address with custom port",
|
|
||||||
address: "2001:db8::1",
|
|
||||||
port: 5353,
|
|
||||||
expected: "[2001:db8::1]:5353",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 localhost",
|
|
||||||
address: "::1",
|
|
||||||
port: 53,
|
|
||||||
expected: "[::1]:53",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid address treated as hostname",
|
|
||||||
address: "dns.example.com",
|
|
||||||
port: 53,
|
|
||||||
expected: "dns.example.com:53",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := formatAddr(tt.address, tt.port)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultPort = 53
|
DefaultPort = 53
|
||||||
)
|
)
|
||||||
|
|
||||||
type service interface {
|
type service interface {
|
||||||
|
@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
|
|||||||
defer s.listenerFlagLock.Unlock()
|
defer s.listenerFlagLock.Unlock()
|
||||||
|
|
||||||
if s.ebpfService != nil {
|
if s.ebpfService != nil {
|
||||||
return defaultPort
|
return DefaultPort
|
||||||
} else {
|
} else {
|
||||||
return int(s.listenPort)
|
return int(s.listenPort)
|
||||||
}
|
}
|
||||||
@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
|
|||||||
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
return s.customAddr.Addr(), s.customAddr.Port(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, ok := s.testFreePort(defaultPort)
|
ip, ok := s.testFreePort(DefaultPort)
|
||||||
if ok {
|
if ok {
|
||||||
return ip, defaultPort, nil
|
return ip, DefaultPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ebpfSrv, port, ok := s.tryToUseeBPF()
|
ebpfSrv, port, ok := s.tryToUseeBPF()
|
||||||
|
@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
|||||||
dnsMux: dns.NewServeMux(),
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
runtimeIP: lastIP,
|
runtimeIP: lastIP,
|
||||||
runtimePort: defaultPort,
|
runtimePort: DefaultPort,
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
|
||||||
if err := s.restoreHostDNS(); err != nil {
|
if err := s.restoreHostDNS(); err != nil {
|
||||||
return fmt.Errorf("restoring dns via systemd: %w", err)
|
return fmt.Errorf("restoring dns via systemd: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create previous host manager: %w", err)
|
return fmt.Errorf("create previous host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -48,7 +49,7 @@ type upstreamResolverBase struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
upstreamClient upstreamClient
|
upstreamClient upstreamClient
|
||||||
upstreamServers []string
|
upstreamServers []netip.AddrPort
|
||||||
domain string
|
domain string
|
||||||
disabled bool
|
disabled bool
|
||||||
failsCount atomic.Int32
|
failsCount atomic.Int32
|
||||||
@ -79,17 +80,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
|||||||
|
|
||||||
// String returns a string representation of the upstream resolver
|
// String returns a string representation of the upstream resolver
|
||||||
func (u *upstreamResolverBase) String() string {
|
func (u *upstreamResolverBase) String() string {
|
||||||
return fmt.Sprintf("upstream %v", u.upstreamServers)
|
return fmt.Sprintf("upstream %s", u.upstreamServers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||||
servers := slices.Clone(u.upstreamServers)
|
servers := slices.Clone(u.upstreamServers)
|
||||||
slices.Sort(servers)
|
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
hash.Write([]byte(u.domain + ":"))
|
hash.Write([]byte(u.domain + ":"))
|
||||||
hash.Write([]byte(strings.Join(servers, ",")))
|
for _, s := range servers {
|
||||||
|
hash.Write([]byte(s.String()))
|
||||||
|
hash.Write([]byte("|"))
|
||||||
|
}
|
||||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,7 +134,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
func() {
|
func() {
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r)
|
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -197,7 +201,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
|||||||
proto.SystemEvent_DNS,
|
proto.SystemEvent_DNS,
|
||||||
"All upstream servers failed (fail count exceeded)",
|
"All upstream servers failed (fail count exceeded)",
|
||||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
map[string]string{"upstreams": u.upstreamServersString()},
|
||||||
// TODO add domain meta
|
// TODO add domain meta
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -258,7 +262,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
|
|||||||
proto.SystemEvent_DNS,
|
proto.SystemEvent_DNS,
|
||||||
"All upstream servers failed (probe failed)",
|
"All upstream servers failed (probe failed)",
|
||||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||||
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")},
|
map[string]string{"upstreams": u.upstreamServersString()},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -278,7 +282,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
operation := func() error {
|
operation := func() error {
|
||||||
select {
|
select {
|
||||||
case <-u.ctx.Done():
|
case <-u.ctx.Done():
|
||||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers))
|
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -291,7 +295,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff())
|
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
||||||
return fmt.Errorf("upstream check call error")
|
return fmt.Errorf("upstream check call error")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -301,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
|
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||||
u.failsCount.Store(0)
|
u.failsCount.Store(0)
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
u.reactivate()
|
u.reactivate()
|
||||||
@ -331,13 +335,21 @@ func (u *upstreamResolverBase) disable(err error) {
|
|||||||
go u.waitUntilResponse()
|
go u.waitUntilResponse()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
|
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||||
|
var servers []string
|
||||||
|
for _, server := range u.upstreamServers {
|
||||||
|
servers = append(servers, server.String())
|
||||||
|
}
|
||||||
|
return strings.Join(servers, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||||
|
|
||||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||||
if u.hostsDNSHolder.isContain(upstream) {
|
if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
|
||||||
return true
|
return u.hostsDNSHolder.contains(addrPort)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||||
|
} else {
|
||||||
|
upstreamIP = upstreamIP.Unmap()
|
||||||
}
|
}
|
||||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
log.Debugf("using private client to query upstream: %s", upstream)
|
||||||
|
@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
// Convert test servers to netip.AddrPort
|
||||||
|
var servers []netip.AddrPort
|
||||||
|
for _, server := range testCase.InputServers {
|
||||||
|
if addrPort, err := netip.ParseAddrPort(server); err == nil {
|
||||||
|
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resolver.upstreamServers = servers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
cancel()
|
cancel()
|
||||||
@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
}
|
}
|
||||||
resolver.upstreamServers = []string{"0.0.0.0:-1"}
|
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
||||||
|
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
||||||
resolver.failsTillDeact = 0
|
resolver.failsTillDeact = 0
|
||||||
resolver.reactivatePeriod = time.Microsecond * 100
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@ -13,7 +15,7 @@ type MobileDependency struct {
|
|||||||
TunAdapter device.TunAdapter
|
TunAdapter device.TunAdapter
|
||||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
NetworkChangeListener listener.NetworkChangeListener
|
NetworkChangeListener listener.NetworkChangeListener
|
||||||
HostDNSAddresses []string
|
HostDNSAddresses []netip.AddrPort
|
||||||
DnsReadyListener dns.ReadyListener
|
DnsReadyListener dns.ReadyListener
|
||||||
|
|
||||||
// iOS only
|
// iOS only
|
||||||
|
@ -140,7 +140,7 @@ type RosenpassState struct {
|
|||||||
// whether it's enabled, and the last error message encountered during probing.
|
// whether it's enabled, and the last error message encountered during probing.
|
||||||
type NSGroupState struct {
|
type NSGroupState struct {
|
||||||
ID string
|
ID string
|
||||||
Servers []string
|
Servers []netip.AddrPort
|
||||||
Domains []string
|
Domains []string
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Error error
|
Error error
|
||||||
|
@ -1197,8 +1197,14 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|||||||
if dnsState.Error != nil {
|
if dnsState.Error != nil {
|
||||||
err = dnsState.Error.Error()
|
err = dnsState.Error.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var servers []string
|
||||||
|
for _, server := range dnsState.Servers {
|
||||||
|
servers = append(servers, server.String())
|
||||||
|
}
|
||||||
|
|
||||||
pbDnsState := &proto.NSGroupState{
|
pbDnsState := &proto.NSGroupState{
|
||||||
Servers: dnsState.Servers,
|
Servers: servers,
|
||||||
Domains: dnsState.Domains,
|
Domains: dnsState.Domains,
|
||||||
Enabled: dnsState.Enabled,
|
Enabled: dnsState.Enabled,
|
||||||
Error: err,
|
Error: err,
|
||||||
|
@ -102,6 +102,11 @@ func (n *NameServer) IsEqual(other *NameServer) bool {
|
|||||||
other.Port == n.Port
|
other.Port == n.Port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddrPort returns the nameserver as a netip.AddrPort
|
||||||
|
func (n *NameServer) AddrPort() netip.AddrPort {
|
||||||
|
return netip.AddrPortFrom(n.IP, uint16(n.Port))
|
||||||
|
}
|
||||||
|
|
||||||
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
|
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
|
||||||
func ParseNameServerURL(nsURL string) (NameServer, error) {
|
func ParseNameServerURL(nsURL string) (NameServer, error) {
|
||||||
parsedURL, err := url.Parse(nsURL)
|
parsedURL, err := url.Parse(nsURL)
|
||||||
|
Reference in New Issue
Block a user