Feature/permanent dns (#967)

* Add DNS list argument for mobile client

* Write testable code

Many places are checked the wgInterface != nil condition.
It is doing it just because to avoid the real wgInterface creation for tests.
Instead of this involve a wgInterface interface what is moc-able.

* Refactor the DNS server internal code structure

With the fake resolver has been involved several
if-else statement and generated some unused
variables to distinguish the listener and fake
resolver solutions at running time. With this
commit the fake resolver and listener based
solution has been moved into two separated
structure. Name of this layer is the 'service'.
With this modification the unit test looks
simpler and open the option to add new logic for
the permanent DNS service usage for mobile
systems.



* Remove is running check in test

We can not ensure the state well so remove this
check. The test will fail if the server is not
running well.
This commit is contained in:
Zoltan Papp 2023-07-14 21:56:22 +02:00 committed by GitHub
parent 9c2c0e7934
commit 7ebe58f20a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 925 additions and 366 deletions

View File

@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@ -35,6 +36,11 @@ type RouteListener interface {
routemanager.RouteListener routemanager.RouteListener
} }
// DnsReadyListener export internal dns ReadyListener for mobile
type DnsReadyListener interface {
dns.ReadyListener
}
func init() { func init() {
formatter.SetLogcatFormatter(log.StandardLogger()) formatter.SetLogcatFormatter(log.StandardLogger())
} }
@ -49,6 +55,7 @@ type Client struct {
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
routeListener routemanager.RouteListener routeListener routemanager.RouteListener
onHostDnsFn func([]string)
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener) c.onHostDnsFn = func([]string) {}
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()
if err != nil {
return err
}
dnsServer.OnUpdatedHostDNSServer(list.items)
return nil
}
// SetConnectionListener set the network connection listener // SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener) c.recorder.SetConnectionListener(listener)

View File

@ -0,0 +1,26 @@
package android
import "fmt"
// DNSList is a wrapper of []string
type DNSList struct {
items []string
}
// Add new DNS address to the collection
func (array *DNSList) Add(s string) {
array.items = append(array.items, s)
}
// Get return an element of the collection
func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range")
}
return array.items[i], nil
}
// Size return with the size of the collection
func (array *DNSList) Size() int {
return len(array.items)
}

View File

@ -0,0 +1,24 @@
package android
import "testing"
func TestDNSList_Get(t *testing.T) {
l := DNSList{
items: make([]string, 1),
}
_, err := l.Get(0)
if err != nil {
t.Errorf("invalid error: %s", err)
}
_, err = l.Get(-1)
if err == nil {
t.Errorf("expected error but got nil")
}
_, err = l.Get(1)
if err == nil {
t.Errorf("expected error but got nil")
}
}

View File

@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil) return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@ -24,7 +25,24 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{})
}
// RunClientMobile with main logic on mobile system
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return wrapErr(err) return wrapErr(err)
} }
// in case of non Android os these variables will be nil engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)

View File

@ -1,13 +1,9 @@
package dns package dns
import (
"github.com/netbirdio/netbird/iface"
)
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }

View File

@ -9,8 +9,6 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@ -34,7 +32,7 @@ type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
} }
func newHostManager(_ *iface.WGIface) (hostManager, error) { func newHostManager(_ WGIface) (hostManager, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil

View File

@ -5,10 +5,10 @@ package dns
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os" "os"
"strings" "strings"
log "github.com/sirupsen/logrus"
) )
const ( const (
@ -25,7 +25,7 @@ const (
type osManagerType int type osManagerType int
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -6,8 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@ -33,7 +31,7 @@ type registryConfigurator struct {
existingSearchDomains []string existingSearchDomains []string
} }
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString() guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
return "" return ""
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
//TODO implement me
panic("implement me")
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface // UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil { if m.UpdateDNSServerFunc != nil {

View File

@ -14,8 +14,6 @@ import (
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -8,8 +8,6 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@ -18,7 +16,7 @@ type resolvconf struct {
ifaceName string ifaceName string
} }
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
return &resolvconf{ return &resolvconf{
ifaceName: wgInterface.Name(), ifaceName: wgInterface.Name(),
}, nil }, nil

View File

@ -3,29 +3,20 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"math/big"
"net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
) )
const ( // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
defaultPort = 53 type ReadyListener interface {
customPort = 5053 OnReady()
defaultIP = "127.0.0.1" }
customIP = "127.0.0.153"
)
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
@ -33,6 +24,7 @@ type Server interface {
Stop() Stop()
DnsIP() string DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
} }
type registeredHandlerMap map[string]handlerWithStop type registeredHandlerMap map[string]handlerWithStop
@ -42,21 +34,19 @@ type DefaultServer struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
mux sync.Mutex mux sync.Mutex
udpFilterHookID string service service
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap dnsMuxMap registeredHandlerMap
localResolver *localResolver localResolver *localResolver
wgInterface *iface.WGIface wgInterface WGIface
hostManager hostManager hostManager hostManager
updateSerial uint64 updateSerial uint64
listenerIsRunning bool
runtimePort int
runtimeIP string
previousConfigHash uint64 previousConfigHash uint64
currentConfig hostDNSConfig currentConfig hostDNSConfig
customAddress *netip.AddrPort
enabled bool // permanent related properties
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
} }
type handlerWithStop interface { type handlerWithStop interface {
@ -70,9 +60,7 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) { func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
mux := dns.NewServeMux()
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@ -82,37 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
addrPort = &parsedAddrPort addrPort = &parsedAddrPort
} }
ctx, stop := context.WithCancel(ctx) var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
setServerDns(ds)
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
server: &dns.Server{ service: dnsService,
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
customAddress: addrPort,
} }
if initialDnsCfg != nil { return defaultServer
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
}
if wgInterface.IsUserspaceBind() {
defaultServer.evelRuntimeAddressForUserspace()
}
return defaultServer, nil
} }
// Initialize instantiate host manager. It required to be initialized wginterface // Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) { func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@ -121,72 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
return nil return nil
} }
if !s.wgInterface.IsUserspaceBind() { if s.permanent {
s.evalRuntimeAddress() err = s.service.Listen()
if err != nil {
return err
} }
}
s.hostManager, err = newHostManager(s.wgInterface) s.hostManager, err = newHostManager(s.wgInterface)
return return
} }
// listen runs the listener in a go routine
func (s *DefaultServer) listen() {
// nil check required in unit tests
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.udpFilterHookID = s.filterDNSTraffic()
s.setListenerStatus(true)
return
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
}
// DnsIP returns the DNS resolver server IP address // DnsIP returns the DNS resolver server IP address
// //
// When kernel space interface used it return real DNS server listener IP address // When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network) // For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string { func (s *DefaultServer) DnsIP() string {
if !s.enabled { return s.service.RuntimeIP()
return ""
}
return s.runtimeIP
}
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
} }
// Stop stops the server // Stop stops the server
@ -202,37 +148,23 @@ func (s *DefaultServer) Stop() {
} }
} }
err := s.stopListener() s.service.Stop()
if err != nil {
log.Error(err)
}
} }
func (s *DefaultServer) stopListener() error { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning { // It will be applied if the mgm server do not enforce DNS settings for root zone
// udpFilterHookID here empty only in the unit tests func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" { s.hostsDnsListLock.Lock()
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil { defer s.hostsDnsListLock.Unlock()
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}
s.udpFilterHookID = ""
s.listenerIsRunning = false
return nil
}
if !s.listenerIsRunning { s.hostsDnsList = hostsDnsList
return nil _, ok := s.dnsMuxMap[nbdns.RootZone]
if ok {
log.Debugf("on new host DNS config but skip to apply it")
return
} }
log.Debugf("update host DNS settings: %+v", hostsDnsList)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) s.addHostRootZone()
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
} }
// UpdateDNSServer processes an update received from the management service // UpdateDNSServer processes an update received from the management service
@ -283,12 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver // is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable { if update.ServiceEnable {
if err := s.stopListener(); err != nil { _ = s.service.Listen()
log.Error(err) } else if !s.permanent {
} s.service.Stop()
} else if !s.listenerIsRunning {
s.listen()
} }
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
@ -299,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err != nil { if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err) return fmt.Errorf("not applying dns update, error: %v", err)
} }
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates) s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords) s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
hostUpdate := s.currentConfig hostUpdate := s.currentConfig
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver") "Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false hostUpdate.routeAll = false
@ -412,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap) muxUpdateMap := make(registeredHandlerMap)
var isContainRootUpdate bool
for _, update := range muxUpdates { for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler) s.service.RegisterMux(update.domain, update.handler)
muxUpdateMap[update.domain] = update.handler muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop() existingHandler.stop()
} }
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
} }
for key, existingHandler := range s.dnsMuxMap { for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop() existingHandler.stop()
s.deregisterMux(key) } else {
existingHandler.stop()
s.service.DeregisterMux(key)
}
} }
} }
@ -455,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port) return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
} }
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@ -490,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
for i, item := range s.currentConfig.domains { for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found { if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true s.currentConfig.domains[i].disabled = true
s.deregisterMux(item.domain) s.service.DeregisterMux(item.domain)
removeIndex[item.domain] = i removeIndex[item.domain] = i
} }
} }
@ -507,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue continue
} }
s.currentConfig.domains[i].disabled = false s.currentConfig.domains[i].disabled = false
s.registerMux(domain, handler) s.service.RegisterMux(domain, handler)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
@ -523,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
return return
} }
func (s *DefaultServer) filterDNSTraffic() string { func (s *DefaultServer) addHostRootZone() {
filter := s.wgInterface.GetFilter() handler := newUpstreamResolver(s.ctx)
if filter == nil { handler.upstreamServers = make([]string, len(s.hostsDnsList))
log.Error("can't set DNS filter, filter not initialized") for n, ua := range s.hostsDnsList {
return "" handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
} }
handler.deactivate = func() {}
firstLayerDecoder := layers.LayerTypeIPv4 handler.reactivate = func() {}
if s.wgInterface.Address().Network.IP.To4() == nil { s.service.RegisterMux(nbdns.RootZone, handler)
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}
func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
return
}
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}
func hasValidDnsServer(cfg *nbdns.Config) bool {
for _, c := range cfg.NameServerGroups {
if c.Primary {
return true
}
}
return false
} }

View File

@ -0,0 +1,29 @@
package dns
import (
"fmt"
"sync"
)
var (
mutex sync.Mutex
server Server
)
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
func GetServerDns() (Server, error) {
mutex.Lock()
if server == nil {
mutex.Unlock()
return nil, fmt.Errorf("DNS server not instantiated yet")
}
s := server
mutex.Unlock()
return s, nil
}
func setServerDns(newServerServer Server) {
mutex.Lock()
server = newServerServer
defer mutex.Unlock()
}

View File

@ -0,0 +1,24 @@
package dns
import (
"testing"
)
func TestGetServerDns(t *testing.T) {
_, err := GetServerDns()
if err == nil {
t.Errorf("invalid dns server instance")
}
srv := &MockServer{}
setServerDns(srv)
srvB, err := GetServerDns()
if err != nil {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
t.Errorf("missmatch dns instances")
}
}

View File

@ -11,14 +11,53 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/miekg/dns" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks" pfmock "github.com/netbirdio/netbird/iface/mocks"
) )
type mocWGIface struct {
filter iface.PacketFilter
}
func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() iface.WGAddress {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{
IP: ip,
Network: network,
}
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
panic("implement me")
}
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
panic("implement me")
}
func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
w.filter = filter
return nil
}
var zoneRecords = []nbdns.SimpleRecord{ var zoneRecords = []nbdns.SimpleRecord{
{ {
Name: "peera.netbird.cloud", Name: "peera.netbird.cloud",
@ -29,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
}, },
} }
func init() {
log.SetLevel(log.TraceLevel)
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
{ {
@ -224,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -242,8 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil { if err != nil {
@ -282,7 +324,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED") ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov) defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
os.Setenv("NB_WG_KERNEL_DISABLED", "true") _ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil) newNet, err := stdnet.NewNet(nil)
if err != nil { if err != nil {
t.Errorf("create stdnet: %v", err) t.Errorf("create stdnet: %v", err)
@ -316,17 +358,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().SetNetwork(ipNet)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes() packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil { if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err) t.Errorf("set packet filter: %v", err)
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@ -421,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort) dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
if err != nil {
dnsServer.hostManager = newNoopHostMocker() t.Fatalf("%v", err)
dnsServer.listen()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
} }
dnsServer.hostManager = newNoopHostMocker()
err = dnsServer.service.Listen()
if err != nil {
t.Fatalf("dns server is not running: %s", err)
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop() defer dnsServer.Stop()
err := dnsServer.localResolver.registerRecord(zoneRecords[0]) err = dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver) dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{ resolver := &net.Resolver{
PreferGo: true, PreferGo: true,
@ -443,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
d := net.Dialer{ d := net.Dialer{
Timeout: time.Second * 5, Timeout: time.Second * 5,
} }
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort) addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
conn, err := d.DialContext(ctx, network, addr) conn, err := d.DialContext(ctx, network, addr)
if err != nil { if err != nil {
t.Log(err) t.Log(err)
@ -478,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
dnsMux: dns.DefaultServeMux, service: newServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
@ -541,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
} }
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer { func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
mux := dns.NewServeMux() wgIFace, err := createWgInterfaceWithBind(t)
var parsedAddrPort *netip.AddrPort
if addrPort != "" {
parsed, err := netip.ParseAddrPort(addrPort)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal("failed to initialize wg interface")
}
parsedAddrPort = &parsed
} }
defer wgIFace.Close()
dnsServer := &dns.Server{ var dnsList []string
Net: "udp", dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
Handler: mux, err = dnsServer.Initialize()
UDPSize: 65535,
}
ctx, cancel := context.WithCancel(context.TODO())
ds := &DefaultServer{
ctx: ctx,
ctxCancel: cancel,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
customAddress: parsedAddrPort,
}
ds.evalRuntimeAddress()
return ds
}
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil { if err != nil {
t.Errorf("Error parsing CIDR: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
return return
} }
defer dnsServer.Stop()
lastIP := getLastIPFromNetwork(ipnet, 1) dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
} _, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Enabled: true,
Primary: true,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
update2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{},
}
err = dnsServer.UpdateDNSServer(2, update2)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
}
func TestDNSPermanent_matchOnly(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, nil
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
}
err = wgIface.Create()
if err != nil {
t.Fatalf("crate and init wireguard interface: %v", err)
return nil, err
}
pf, err := uspfilter.Create(wgIface)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
}
err = wgIface.SetFilter(pf)
if err != nil {
t.Fatalf("set packet filter: %v", err)
return nil, err
}
return wgIface, nil
}
func newDnsResolver(ip string, port int) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 3,
}
addr := fmt.Sprintf("%s:%d", ip, port)
return d.DialContext(ctx, network, addr)
},
} }
} }

View File

@ -0,0 +1,18 @@
package dns
import (
"github.com/miekg/dns"
)
const (
defaultPort = 53
)
type service interface {
Listen() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int
RuntimeIP() string
}

View File

@ -0,0 +1,145 @@
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
type serviceViaListener struct {
wgInterface WGIface
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
runtimeIP string
runtimePort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
}
return s
}
func (s *serviceViaListener) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
return nil
}
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaListener) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}

View File

@ -0,0 +1,139 @@
package dns
import (
"fmt"
"math/big"
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
runtimePort int
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
runtimePort: defaultPort,
}
return s
}
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return err
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@ -0,0 +1,31 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@ -15,7 +15,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.Name()) iface, err := net.InterfaceByName(wgInterface.Name())
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -0,0 +1,14 @@
//go:build !windows
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
}

View File

@ -0,0 +1,13 @@
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
GetInterfaceGUIDString() (string, error)
}

View File

@ -190,23 +190,25 @@ func (e *Engine) Start() error {
} }
var routes []*route.Route var routes []*route.Route
var dnsCfg *nbdns.Config
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
routes, dnsCfg, err = e.readInitialSettings() routes, err = e.readInitialSettings()
if err != nil { if err != nil {
return err return err
} }
}
if e.dnsServer == nil { if e.dnsServer == nil {
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
go e.mobileDep.DnsReadyListener.OnReady()
}
} else {
// todo fix custom address // todo fix custom address
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg) if e.dnsServer == nil {
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil { if err != nil {
e.close() e.close()
return err return err
} }
e.dnsServer = dnsServer }
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
@ -1045,14 +1047,13 @@ func (e *Engine) close() {
} }
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, error) {
netMap, err := e.mgmClient.GetNetworkMap() netMap, err := e.mgmClient.GetNetworkMap()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
routes := toRoutes(netMap.GetRoutes()) routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, nil
return routes, &dnsCfg, nil
} }
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {

View File

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@ -11,4 +12,6 @@ type MobileDependency struct {
TunAdapter iface.TunAdapter TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
RouteListener routemanager.RouteListener RouteListener routemanager.RouteListener
HostDNSAddresses []string
DnsReadyListener dns.ReadyListener
} }

View File

@ -102,7 +102,7 @@ func (s *Server) Start() error {
} }
go func() { go func() {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil { if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
log.Errorf("init connections: %v", err) log.Errorf("init connections: %v", err)
} }
}() }()
@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
go func() { go func() {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil { if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
log.Errorf("run client connection: %v", err) log.Errorf("run client connection: %v", err)
return return
} }

View File

@ -5,7 +5,6 @@ import (
"sync" "sync"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create(mIFaceArgs) return w.tun.Create(mIFaceArgs)
} }

View File

@ -7,7 +7,6 @@ import (
"sync" "sync"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create() return w.tun.Create()
} }

View File

@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
} }
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
log.Info("create tun interface")
var err error var err error
routesString := t.routesToString(mIFaceArgs.Routes) routesString := t.routesToString(mIFaceArgs.Routes)
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString) t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)

View File

@ -12,14 +12,14 @@ import (
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
if WireGuardModuleIsLoaded() { if WireGuardModuleIsLoaded() {
log.Info("using kernel WireGuard") log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
return c.createWithKernel() return c.createWithKernel()
} }
if !tunModuleIsLoaded() { if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module") return fmt.Errorf("couldn't check or load tun module")
} }
log.Info("using userspace WireGuard") log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
var err error var err error
c.netInterface, err = c.createWithUserspace() c.netInterface, err = c.createWithUserspace()
if err != nil { if err != nil {