[client] Always register NetBird with plain Linux DNS and use original servers as upstream (#3967)

This commit is contained in:
Viktor Liu
2025-07-25 11:46:04 +02:00
committed by GitHub
parent af8687579b
commit cb85d3f2fc
20 changed files with 196 additions and 259 deletions

View File

@ -4,8 +4,8 @@ package dns
import ( import (
"fmt" "fmt"
"net/netip"
"os" "os"
"regexp"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -15,9 +15,6 @@ const (
defaultResolvConfPath = "/etc/resolv.conf" defaultResolvConfPath = "/etc/resolv.conf"
) )
var timeoutRegex = regexp.MustCompile(`timeout:\d+`)
var attemptsRegex = regexp.MustCompile(`attempts:\d+`)
type resolvConf struct { type resolvConf struct {
nameServers []string nameServers []string
searchDomains []string searchDomains []string
@ -108,40 +105,9 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
return rconf, nil return rconf, nil
} }
// prepareOptionsWithTimeout appends timeout to existing options if it doesn't exist,
// otherwise it adds a new option with timeout and attempts.
func prepareOptionsWithTimeout(input []string, timeout int, attempts int) []string {
configs := make([]string, len(input))
copy(configs, input)
for i, config := range configs {
if strings.HasPrefix(config, "options") {
config = strings.ReplaceAll(config, "rotate", "")
config = strings.Join(strings.Fields(config), " ")
if strings.Contains(config, "timeout:") {
config = timeoutRegex.ReplaceAllString(config, fmt.Sprintf("timeout:%d", timeout))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options timeout:%d ", timeout), 1)
}
if strings.Contains(config, "attempts:") {
config = attemptsRegex.ReplaceAllString(config, fmt.Sprintf("attempts:%d", attempts))
} else {
config = strings.Replace(config, "options ", fmt.Sprintf("options attempts:%d ", attempts), 1)
}
configs[i] = config
return configs
}
}
return append(configs, fmt.Sprintf("options timeout:%d attempts:%d", timeout, attempts))
}
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position // removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location // and writes the file back to the original location
func removeFirstNbNameserver(filename, nameserverIP string) error { func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename) resolvConf, err := parseResolvConfFile(filename)
if err != nil { if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err) return fmt.Errorf("parse backup resolv.conf: %w", err)
@ -151,7 +117,7 @@ func removeFirstNbNameserver(filename, nameserverIP string) error {
return fmt.Errorf("read %s: %w", filename, err) return fmt.Errorf("read %s: %w", filename, err)
} }
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP { if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1) newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename) stat, err := os.Stat(filename)

View File

@ -3,11 +3,13 @@
package dns package dns
import ( import (
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseResolvConf(t *testing.T) { func Test_parseResolvConf(t *testing.T) {
@ -175,52 +177,6 @@ nameserver 192.168.0.1
} }
} }
func TestPrepareOptionsWithTimeout(t *testing.T) {
tests := []struct {
name string
others []string
timeout int
attempts int
expected []string
}{
{
name: "Append new options with timeout and attempts",
others: []string{"some config"},
timeout: 2,
attempts: 2,
expected: []string{"some config", "options timeout:2 attempts:2"},
},
{
name: "Modify existing options to exclude rotate and include timeout and attempts",
others: []string{"some config", "options rotate someother"},
timeout: 3,
attempts: 2,
expected: []string{"some config", "options attempts:2 timeout:3 someother"},
},
{
name: "Existing options with timeout and attempts are updated",
others: []string{"some config", "options timeout:4 attempts:3"},
timeout: 5,
attempts: 4,
expected: []string{"some config", "options timeout:5 attempts:4"},
},
{
name: "Modify existing options, add missing attempts before timeout",
others: []string{"some config", "options timeout:4"},
timeout: 4,
attempts: 3,
expected: []string{"some config", "options attempts:3 timeout:4"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := prepareOptionsWithTimeout(tc.others, tc.timeout, tc.attempts)
assert.Equal(t, tc.expected, result)
})
}
}
func TestRemoveFirstNbNameserver(t *testing.T) { func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@ -292,7 +248,9 @@ search localdomain`,
err := os.WriteFile(tempFile, []byte(tc.content), 0644) err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err) assert.NoError(t, err)
err = removeFirstNbNameserver(tempFile, tc.ipToRemove) ip, err := netip.ParseAddr(tc.ipToRemove)
require.NoError(t, err, "Failed to parse IP address")
err = removeFirstNbNameserver(tempFile, ip)
assert.NoError(t, err) assert.NoError(t, err)
content, err := os.ReadFile(tempFile) content, err := os.ReadFile(tempFile)

View File

@ -3,6 +3,7 @@
package dns package dns
import ( import (
"net/netip"
"path" "path"
"path/filepath" "path/filepath"
"sync" "sync"
@ -22,7 +23,7 @@ var (
} }
) )
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error type repairConfFn func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error
type repair struct { type repair struct {
operationFile string operationFile string
@ -42,7 +43,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
} }
} }
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP netip.Addr, stateManager *statemanager.Manager) {
if f.inotify != nil { if f.inotify != nil {
return return
} }
@ -136,7 +137,7 @@ func (f *repair) isEventRelevant(event fsnotify.Event) bool {
// nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs // nbParamsAreMissing checks if the resolv.conf file contains all the parameters that NetBird needs
// check the NetBird related nameserver IP at the first place // check the NetBird related nameserver IP at the first place
// check the NetBird related search domains in the search domains list // check the NetBird related search domains in the search domains list
func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *resolvConf) bool { func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rConf *resolvConf) bool {
if !isContains(nbSearchDomains, rConf.searchDomains) { if !isContains(nbSearchDomains, rConf.searchDomains) {
return true return true
} }
@ -145,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP string, rConf *r
return true return true
} }
if rConf.nameServers[0] != nbNameserverIP { if rConf.nameServers[0] != nbNameserverIP.String() {
return true return true
} }

View File

@ -4,6 +4,7 @@ package dns
import ( import (
"context" "context"
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -105,14 +106,14 @@ nameserver 8.8.8.8`,
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { updateFn := func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(operationFile, updateFn) r := newRepair(operationFile, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) r.watchFileChanges([]string{"netbird.cloud"}, netip.MustParseAddr("10.0.0.1"), nil)
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
if err != nil { if err != nil {
@ -152,14 +153,14 @@ searchdomain netbird.cloud something`
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { updateFn := func([]string, netip.Addr, *resolvConf, *statemanager.Manager) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(tmpLink, updateFn) r := newRepair(tmpLink, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) r.watchFileChanges([]string{"netbird.cloud"}, netip.MustParseAddr("10.0.0.1"), nil)
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
if err != nil { if err != nil {

View File

@ -8,7 +8,6 @@ import (
"net/netip" "net/netip"
"os" "os"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -18,7 +17,7 @@ import (
const ( const (
fileGeneratedResolvConfContentHeader = "# Generated by NetBird" fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + ` fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n" # The original file can be restored from ` + fileDefaultResolvConfBackupLocation + "\n\n"
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
@ -26,16 +25,11 @@ const (
fileMaxNumberOfSearchDomains = 6 fileMaxNumberOfSearchDomains = 6
) )
const (
dnsFailoverTimeout = 4 * time.Second
dnsFailoverAttempts = 1
)
type fileConfigurator struct { type fileConfigurator struct {
repair *repair repair *repair
originalPerms os.FileMode
originalPerms os.FileMode nbNameserverIP netip.Addr
nbNameserverIP string originalNameservers []string
} }
func newFileConfigurator() (*fileConfigurator, error) { func newFileConfigurator() (*fileConfigurator, error) {
@ -49,22 +43,9 @@ func (f *fileConfigurator) supportCustomPort() bool {
} }
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
backupFileExist := f.isBackupFileExist() if !f.isBackupFileExist() {
if !config.RouteAll { if err := f.backup(); err != nil {
if backupFileExist { return fmt.Errorf("backup resolv.conf: %w", err)
f.repair.stopWatchFileChanges()
err := f.restore()
if err != nil {
return fmt.Errorf("restoring the original resolv.conf file return err: %w", err)
}
}
return ErrRouteAllWithoutNameserverGroup
}
if !backupFileExist {
err := f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file: %w", err)
} }
} }
@ -76,6 +57,8 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err) log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, err)
} }
f.originalNameservers = resolvConf.nameServers
f.repair.stopWatchFileChanges() f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
@ -86,15 +69,19 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
return nil return nil
} }
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { // getOriginalNameservers returns the nameservers that were found in the original resolv.conf
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) func (f *fileConfigurator) getOriginalNameservers() []string {
nameServers := generateNsList(nbNameserverIP, cfg) return f.originalNameservers
}
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP netip.Addr, cfg *resolvConf, stateManager *statemanager.Manager) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent( buf := prepareResolvConfContent(
searchDomainList, searchDomainList,
nameServers, []string{nbNameserverIP.String()},
options) cfg.others,
)
log.Debugf("creating managed file %s", defaultResolvConfPath) log.Debugf("creating managed file %s", defaultResolvConfPath)
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms) err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
@ -197,38 +184,28 @@ func restoreResolvConfFile() error {
return nil return nil
} }
// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
ns := make([]string, 1, len(cfg.nameServers)+1)
ns[0] = nbNameserverIP
for _, cfgNs := range cfg.nameServers {
if nbNameserverIP != cfgNs {
ns = append(ns, cfgNs)
}
}
return ns
}
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer { func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine) buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
for _, cfgLine := range others { for _, cfgLine := range others {
buf.WriteString(cfgLine) buf.WriteString(cfgLine)
buf.WriteString("\n") buf.WriteByte('\n')
} }
if len(searchDomains) > 0 { if len(searchDomains) > 0 {
buf.WriteString("search ") buf.WriteString("search ")
buf.WriteString(strings.Join(searchDomains, " ")) buf.WriteString(strings.Join(searchDomains, " "))
buf.WriteString("\n") buf.WriteByte('\n')
} }
for _, ns := range nameServers { for _, ns := range nameServers {
buf.WriteString("nameserver ") buf.WriteString("nameserver ")
buf.WriteString(ns) buf.WriteString(ns)
buf.WriteString("\n") buf.WriteByte('\n')
} }
return buf return buf
} }

View File

@ -15,6 +15,7 @@ const (
PriorityDNSRoute = 75 PriorityDNSRoute = 75
PriorityUpstream = 50 PriorityUpstream = 50
PriorityDefault = 1 PriorityDefault = 1
PriorityFallback = -100
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@ -191,7 +192,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// No handler matched or all handlers passed // No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname) log.Tracef("no handler found for domain=%s", qname)
resp := &dns.Msg{} resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError) resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }

View File

@ -11,8 +11,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
const ( const (
ipv4ReverseZone = ".in-addr.arpa." ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa." ipv6ReverseZone = ".ip6.arpa."
@ -27,14 +25,14 @@ type hostManager interface {
type SystemDNSSettings struct { type SystemDNSSettings struct {
Domains []string Domains []string
ServerIP string ServerIP netip.Addr
ServerPort int ServerPort int
} }
type HostDNSConfig struct { type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"` Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"` RouteAll bool `json:"routeAll"`
ServerIP string `json:"serverIP"` ServerIP netip.Addr `json:"serverIP"`
ServerPort int `json:"serverPort"` ServerPort int `json:"serverPort"`
} }
@ -89,7 +87,7 @@ func newNoopHostMocker() hostManager {
} }
} }
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig { func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) HostDNSConfig {
config := HostDNSConfig{ config := HostDNSConfig{
RouteAll: false, RouteAll: false,
ServerIP: ip, ServerIP: ip,

View File

@ -7,7 +7,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net" "net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
@ -165,13 +165,13 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
} }
func (s *systemConfigurator) addLocalDNS() error { func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true) err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration") log.Errorf("Unable to get system DNS configuration")
return err return err
} }
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 { if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil { if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err) return fmt.Errorf("couldn't add local network DNS conf: %w", err)
@ -184,7 +184,7 @@ func (s *systemConfigurator) addLocalDNS() error {
} }
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force { if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil return nil
} }
@ -238,8 +238,8 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil { if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = address dnsSettings.ServerIP = ip
inServerAddressesArray = false // Stop reading after finding the first IPv4 address inServerAddressesArray = false // Stop reading after finding the first IPv4 address
} }
} }
@ -250,12 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
// default to 53 port // default to 53 port
dnsSettings.ServerPort = 53 dnsSettings.ServerPort = defaultPort
return dnsSettings, nil return dnsSettings, nil
} }
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {
return fmt.Errorf("add dns state: %w", err) return fmt.Errorf("add dns state: %w", err)
@ -268,7 +268,7 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, po
return nil return nil
} }
func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error { func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
err := s.addDNSState(key, domains, dnsServer, port, false) err := s.addDNSState(key, domains, dnsServer, port, false)
if err != nil { if err != nil {
return fmt.Errorf("add dns state: %w", err) return fmt.Errorf("add dns state: %w", err)
@ -281,14 +281,14 @@ func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, por
return nil return nil
} }
func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port int, enableSearch bool) error { func (s *systemConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error {
noSearch := "1" noSearch := "1"
if enableSearch { if enableSearch {
noSearch = "0" noSearch = "0"
} }
lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains) lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains)
lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch) lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch)
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer.String())
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(state, lines) addDomainCommand := buildCreateStateWithOperation(state, lines)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/netip"
"os/exec" "os/exec"
"strings" "strings"
"syscall" "syscall"
@ -210,8 +211,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil return nil
} }
func (r *registryConfigurator) addDNSSetupForAll(ip string) error { func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
} }
r.routingAll = true r.routingAll = true
@ -219,7 +220,7 @@ func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
return nil return nil
} }
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
if r.gpo { if r.gpo {
@ -241,7 +242,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
} }
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path // configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip string) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
} }
@ -260,7 +261,7 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err) return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err)
} }
if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil { if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip.String()); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err) return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err)
} }

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net/netip"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -45,8 +46,8 @@ func (m *MockServer) Stop() {
} }
} }
func (m *MockServer) DnsIP() string { func (m *MockServer) DnsIP() netip.Addr {
return "" return netip.MustParseAddr("100.10.254.255")
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {

View File

@ -110,11 +110,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
connSettings.cleanDeprecatedSettings() connSettings.cleanDeprecatedSettings()
dnsIP, err := netip.ParseAddr(config.ServerIP) convDNSIP := binary.LittleEndian.Uint32(config.ServerIP.AsSlice())
if err != nil {
return fmt.Errorf("unable to parse ip address, error: %w", err)
}
convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice())
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP})
var ( var (
searchDomains []string searchDomains []string

View File

@ -46,9 +46,9 @@ type resolvconf struct {
func detectResolvconfType() (resolvconfType, error) { func detectResolvconfType() (resolvconfType, error) {
cmd := exec.Command(resolvconfCommand, "--version") cmd := exec.Command(resolvconfCommand, "--version")
out, err := cmd.Output() out, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return typeOpenresolv, fmt.Errorf("failed to determine resolvconf type: %w", err) return typeOpenresolv, fmt.Errorf("determine resolvconf type: %w", err)
} }
if strings.Contains(string(out), "openresolv") { if strings.Contains(string(out), "openresolv") {
@ -66,7 +66,7 @@ func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
implType, err := detectResolvconfType() implType, err := detectResolvconfType()
if err != nil { if err != nil {
log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err) log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err)
implType = typeOpenresolv implType = typeResolvconf
} else { } else {
log.Infof("detected resolvconf type: %v", implType) log.Infof("detected resolvconf type: %v", implType)
} }
@ -85,24 +85,14 @@ func (r *resolvconf) supportCustomPort() bool {
} }
func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if !config.RouteAll {
err = r.restoreHostDNS()
if err != nil {
log.Errorf("restore host dns: %s", err)
}
return ErrRouteAllWithoutNameserverGroup
}
searchDomainList := searchDomains(config) searchDomainList := searchDomains(config)
searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains) searchDomainList = mergeSearchDomains(searchDomainList, r.originalSearchDomains)
options := prepareOptionsWithTimeout(r.othersConfigs, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
buf := prepareResolvConfContent( buf := prepareResolvConfContent(
searchDomainList, searchDomainList,
append([]string{config.ServerIP}, r.originalNameServers...), []string{config.ServerIP.String()},
options) r.othersConfigs,
)
state := &ShutdownState{ state := &ShutdownState{
ManagerType: resolvConfManager, ManagerType: resolvConfManager,
@ -112,8 +102,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
log.Errorf("failed to update shutdown state: %s", err) log.Errorf("failed to update shutdown state: %s", err)
} }
err = r.applyConfig(buf) if err := r.applyConfig(buf); err != nil {
if err != nil {
return fmt.Errorf("apply config: %w", err) return fmt.Errorf("apply config: %w", err)
} }
@ -121,6 +110,10 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
return nil return nil
} }
func (r *resolvconf) getOriginalNameservers() []string {
return r.originalNameServers
}
func (r *resolvconf) restoreHostDNS() error { func (r *resolvconf) restoreHostDNS() error {
var cmd *exec.Cmd var cmd *exec.Cmd
@ -157,7 +150,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
} }
cmd.Stdin = &content cmd.Stdin = &content
out, err := cmd.Output() out, err := cmd.CombinedOutput()
log.Tracef("resolvconf output: %s", out) log.Tracef("resolvconf output: %s", out)
if err != nil { if err != nil {
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)

View File

@ -2,7 +2,6 @@ package dns
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
@ -20,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
) )
@ -41,7 +39,7 @@ type Server interface {
DeregisterHandler(domains domain.List, priority int) DeregisterHandler(domains domain.List, priority int)
Initialize() error Initialize() error
Stop() Stop()
DnsIP() string DnsIP() netip.Addr
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(strings []string)
SearchDomains() []string SearchDomains() []string
@ -53,6 +51,12 @@ type nsGroupsByDomain struct {
groups []*nbdns.NameServerGroup groups []*nbdns.NameServerGroup
} }
// hostManagerWithOriginalNS extends the basic hostManager interface
type hostManagerWithOriginalNS interface {
hostManager
getOriginalNameservers() []string
}
// DefaultServer dns server object // DefaultServer dns server object
type DefaultServer struct { type DefaultServer struct {
ctx context.Context ctx context.Context
@ -215,6 +219,7 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
log.Warn("skipping empty domain") log.Warn("skipping empty domain")
continue continue
} }
s.handlerChain.AddHandler(domain, handler, priority) s.handlerChain.AddHandler(domain, handler, priority)
} }
} }
@ -286,7 +291,7 @@ func (s *DefaultServer) Initialize() (err error) {
// //
// When kernel space interface used it return real DNS server listener IP address // When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network) // For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string { func (s *DefaultServer) DnsIP() netip.Addr {
return s.service.RuntimeIP() return s.service.RuntimeIP()
} }
@ -297,6 +302,11 @@ func (s *DefaultServer) Stop() {
s.ctxCancel() s.ctxCancel()
if s.hostManager != nil { if s.hostManager != nil {
if srvs, ok := s.hostManager.(hostManagerWithOriginalNS); ok && len(srvs.getOriginalNameservers()) > 0 {
log.Debugf("deregistering original nameservers as fallback handlers")
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
}
if err := s.hostManager.restoreHostDNS(); err != nil { if err := s.hostManager.restoreHostDNS(); err != nil {
log.Error("failed to restore host DNS settings: ", err) log.Error("failed to restore host DNS settings: ", err)
} else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
@ -311,7 +321,6 @@ func (s *DefaultServer) Stop() {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
@ -493,25 +502,56 @@ func (s *DefaultServer) applyHostConfig() {
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err) log.Errorf("failed to apply DNS host manager update: %v", err)
s.handleErrNoGroupaAll(err)
} }
s.registerFallback(config)
} }
func (s *DefaultServer) handleErrNoGroupaAll(err error) { // registerFallback registers original nameservers as low-priority fallback handlers
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) { func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
return return
} }
if s.statusRecorder == nil { originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
return return
} }
s.statusRecorder.PublishEvent( log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
cProto.SystemEvent_WARNING, cProto.SystemEvent_DNS,
"The host dns manager does not support match domains", handler, err := newUpstreamResolver(
"The host dns manager does not support match domains without a catch-all nameserver group.", s.ctx,
map[string]string{"manager": s.hostManager.string()}, s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
) )
if err != nil {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
return
}
for _, ns := range originalNameservers {
if ns == config.ServerIP.String() {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
continue
}
ns = fmt.Sprintf("%s:%d", ns, defaultPort)
if ip, err := netip.ParseAddr(ns); err == nil && ip.Is6() {
ns = fmt.Sprintf("[%s]:%d", ns, defaultPort)
}
handler.upstreamServers = append(handler.upstreamServers, ns)
}
handler.deactivate = func(error) { /* always active */ }
handler.reactivate = func() { /* always active */ }
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
} }
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
@ -588,14 +628,8 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i priority := basePriority - i
// Check if we're about to overlap with the next priority tier. // Check if we're about to overlap with the next priority tier
// This boundary check ensures that the priority of upstream handlers does not conflict if s.leaksPriority(domainGroup, basePriority, priority) {
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityUpstream-PriorityDefault)
break break
} }
@ -648,6 +682,21 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
return muxUpdates, nil return muxUpdates, nil
} }
func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool {
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityUpstream-PriorityDefault)
return true
}
if basePriority == PriorityDefault && priority <= PriorityFallback {
log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityDefault-PriorityFallback)
return true
}
return false
}
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests // this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap { for _, existing := range s.dnsMuxMap {
@ -760,6 +809,12 @@ func (s *DefaultServer) upstreamCallbacks(
} }
func (s *DefaultServer) addHostRootZone() { func (s *DefaultServer) addHostRootZone() {
hostDNSServers := s.hostsDNSHolder.get()
if len(hostDNSServers) == 0 {
log.Debug("no host DNS servers available, skipping root zone handler creation")
return
}
handler, err := newUpstreamResolver( handler, err := newUpstreamResolver(
s.ctx, s.ctx,
s.wgInterface.Name(), s.wgInterface.Name(),
@ -775,7 +830,7 @@ func (s *DefaultServer) addHostRootZone() {
} }
handler.upstreamServers = make([]string, 0) handler.upstreamServers = make([]string, 0)
for k := range s.hostsDNSHolder.get() { for k := range hostDNSServers {
handler.upstreamServers = append(handler.upstreamServers, k) handler.upstreamServers = append(handler.upstreamServers, k)
} }
handler.deactivate = func(error) {} handler.deactivate = func(error) {}

View File

@ -938,7 +938,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return wgIface, nil return wgIface, nil
} }
func newDnsResolver(ip string, port int) *net.Resolver { func newDnsResolver(ip netip.Addr, port int) *net.Resolver {
return &net.Resolver{ return &net.Resolver{
PreferGo: true, PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
@ -1047,7 +1047,7 @@ type mockService struct{}
func (m *mockService) Listen() error { return nil } func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() {} func (m *mockService) Stop() {}
func (m *mockService) RuntimeIP() string { return "127.0.0.1" } func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
func (m *mockService) RuntimePort() int { return 53 } func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {} func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {} func (m *mockService) DeregisterMux(string) {}

View File

@ -1,6 +1,8 @@
package dns package dns
import ( import (
"net/netip"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -14,5 +16,5 @@ type service interface {
RegisterMux(domain string, handler dns.Handler) RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string) DeregisterMux(key string)
RuntimePort() int RuntimePort() int
RuntimeIP() string RuntimeIP() netip.Addr
} }

View File

@ -18,8 +18,11 @@ import (
const ( const (
customPort = 5053 customPort = 5053
defaultIP = "127.0.0.1" )
customIP = "127.0.0.153"
var (
defaultIP = netip.MustParseAddr("127.0.0.1")
customIP = netip.MustParseAddr("127.0.0.153")
) )
type serviceViaListener struct { type serviceViaListener struct {
@ -27,7 +30,7 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
customAddr *netip.AddrPort customAddr *netip.AddrPort
server *dns.Server server *dns.Server
listenIP string listenIP netip.Addr
listenPort uint16 listenPort uint16
listenerIsRunning bool listenerIsRunning bool
listenerFlagLock sync.Mutex listenerFlagLock sync.Mutex
@ -65,6 +68,7 @@ func (s *serviceViaListener) Listen() error {
log.Errorf("failed to eval runtime address: %s", err) log.Errorf("failed to eval runtime address: %s", err)
return fmt.Errorf("eval listen address: %w", err) return fmt.Errorf("eval listen address: %w", err)
} }
s.listenIP = s.listenIP.Unmap()
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort) s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
log.Debugf("starting dns on %s", s.server.Addr) log.Debugf("starting dns on %s", s.server.Addr)
go func() { go func() {
@ -124,7 +128,7 @@ func (s *serviceViaListener) RuntimePort() int {
} }
} }
func (s *serviceViaListener) RuntimeIP() string { func (s *serviceViaListener) RuntimeIP() netip.Addr {
return s.listenIP return s.listenIP
} }
@ -139,9 +143,9 @@ func (s *serviceViaListener) setListenerStatus(running bool) {
// first check the 53 port availability on WG interface or lo, if not success // first check the 53 port availability on WG interface or lo, if not success
// pick a random port on WG interface for eBPF, if not success // pick a random port on WG interface for eBPF, if not success
// check the 5053 port availability on WG interface or lo without eBPF usage, // check the 5053 port availability on WG interface or lo without eBPF usage,
func (s *serviceViaListener) evalListenAddress() (string, uint16, error) { func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
if s.customAddr != nil { if s.customAddr != nil {
return s.customAddr.Addr().String(), s.customAddr.Port(), nil return s.customAddr.Addr(), s.customAddr.Port(), nil
} }
ip, ok := s.testFreePort(defaultPort) ip, ok := s.testFreePort(defaultPort)
@ -152,7 +156,7 @@ func (s *serviceViaListener) evalListenAddress() (string, uint16, error) {
ebpfSrv, port, ok := s.tryToUseeBPF() ebpfSrv, port, ok := s.tryToUseeBPF()
if ok { if ok {
s.ebpfService = ebpfSrv s.ebpfService = ebpfSrv
return s.wgInterface.Address().IP.String(), port, nil return s.wgInterface.Address().IP, port, nil
} }
ip, ok = s.testFreePort(customPort) ip, ok = s.testFreePort(customPort)
@ -160,15 +164,15 @@ func (s *serviceViaListener) evalListenAddress() (string, uint16, error) {
return ip, customPort, nil return ip, customPort, nil
} }
return "", 0, fmt.Errorf("failed to find a free port for DNS server") return netip.Addr{}, 0, fmt.Errorf("failed to find a free port for DNS server")
} }
func (s *serviceViaListener) testFreePort(port int) (string, bool) { func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
var ips []string var ips []netip.Addr
if runtime.GOOS != "darwin" { if runtime.GOOS != "darwin" {
ips = []string{s.wgInterface.Address().IP.String(), defaultIP, customIP} ips = []netip.Addr{s.wgInterface.Address().IP, defaultIP, customIP}
} else { } else {
ips = []string{defaultIP, customIP} ips = []netip.Addr{defaultIP, customIP}
} }
for _, ip := range ips { for _, ip := range ips {
@ -178,10 +182,10 @@ func (s *serviceViaListener) testFreePort(port int) (string, bool) {
return ip, true return ip, true
} }
return "", false return netip.Addr{}, false
} }
func (s *serviceViaListener) tryToBind(ip string, port int) bool { func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := fmt.Sprintf("%s:%d", ip, port) addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString)) udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr) probeListener, err := net.ListenUDP("udp", udpAddr)
@ -224,7 +228,7 @@ func (s *serviceViaListener) tryToUseeBPF() (ebpfMgr.Manager, uint16, bool) {
} }
func (s *serviceViaListener) generateFreePort() (uint16, error) { func (s *serviceViaListener) generateFreePort() (uint16, error) {
ok := s.tryToBind(s.wgInterface.Address().IP.String(), customPort) ok := s.tryToBind(s.wgInterface.Address().IP, customPort)
if ok { if ok {
return customPort, nil return customPort, nil
} }

View File

@ -16,7 +16,7 @@ import (
type ServiceViaMemory struct { type ServiceViaMemory struct {
wgInterface WGIface wgInterface WGIface
dnsMux *dns.ServeMux dnsMux *dns.ServeMux
runtimeIP string runtimeIP netip.Addr
runtimePort int runtimePort int
udpFilterHookID string udpFilterHookID string
listenerIsRunning bool listenerIsRunning bool
@ -32,7 +32,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: lastIP.String(), runtimeIP: lastIP,
runtimePort: defaultPort, runtimePort: defaultPort,
} }
return s return s
@ -84,7 +84,7 @@ func (s *ServiceViaMemory) RuntimePort() int {
return s.runtimePort return s.runtimePort
} }
func (s *ServiceViaMemory) RuntimeIP() string { func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP return s.runtimeIP
} }
@ -121,10 +121,5 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true return true
} }
ip, err := netip.ParseAddr(s.runtimeIP) return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
} }

View File

@ -89,21 +89,16 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
} }
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil {
return fmt.Errorf("unable to parse ip address, error: %w", err)
}
ipAs4 := parsedIP.As4()
defaultLinkInput := systemdDbusDNSInput{ defaultLinkInput := systemdDbusDNSInput{
Family: unix.AF_INET, Family: unix.AF_INET,
Address: ipAs4[:], Address: config.ServerIP.AsSlice(),
} }
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil { if err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err) return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
} }
// We don't support dnssec. On some machines this is default on so we explicitly set it to off // We don't support dnssec. On some machines this is default on so we explicitly set it to off
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil { if err := s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSSEC to 'no': %v", err) log.Warnf("failed to set DNSSEC to 'no': %v", err)
} }
@ -129,8 +124,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
} }
if config.RouteAll { if config.RouteAll {
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true); err != nil {
if err != nil {
return fmt.Errorf("set link as default dns router: %w", err) return fmt.Errorf("set link as default dns router: %w", err)
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
@ -139,7 +133,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
}) })
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} else { } else {
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil { if err := s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
return fmt.Errorf("remove link as default dns router: %w", err) return fmt.Errorf("remove link as default dns router: %w", err)
} }
} }
@ -153,9 +147,8 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
err = s.setDomainsForInterface(domainsInput) if err := s.setDomainsForInterface(domainsInput); err != nil {
if err != nil { log.Error("failed to set domains for interface: ", err)
log.Error(err)
} }
if err := s.flushDNSCache(); err != nil { if err := s.flushDNSCache(); err != nil {

View File

@ -35,12 +35,7 @@ func (s *ShutdownState) Cleanup() error {
} }
// TODO: move file contents to state manager // TODO: move file contents to state manager
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { func createUncleanShutdownIndicator(sourcePath string, dnsAddress netip.Addr, stateManager *statemanager.Manager) error {
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
}
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err) return fmt.Errorf("create dir %s: %w", dir, err)

View File

@ -1550,7 +1550,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
func (e *Engine) wgInterfaceCreate() (err error) { func (e *Engine) wgInterfaceCreate() (err error) {
switch runtime.GOOS { switch runtime.GOOS {
case "android": case "android":
err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP(), e.dnsServer.SearchDomains()) err = e.wgInterface.CreateOnAndroid(e.routeManager.InitialRouteRange(), e.dnsServer.DnsIP().String(), e.dnsServer.SearchDomains())
case "ios": case "ios":
e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr) e.mobileDep.NetworkChangeListener.SetInterfaceIP(e.config.WgAddr)
err = e.wgInterface.Create() err = e.wgInterface.Create()