Feature/permanent dns (#967)

* Add DNS list argument for mobile client

* Write testable code

Many places are checked the wgInterface != nil condition.
It is doing it just because to avoid the real wgInterface creation for tests.
Instead of this involve a wgInterface interface what is moc-able.

* Refactor the DNS server internal code structure

With the fake resolver has been involved several
if-else statement and generated some unused
variables to distinguish the listener and fake
resolver solutions at running time. With this
commit the fake resolver and listener based
solution has been moved into two separated
structure. Name of this layer is the 'service'.
With this modification the unit test looks
simpler and open the option to add new logic for
the permanent DNS service usage for mobile
systems.



* Remove is running check in test

We can not ensure the state well so remove this
check. The test will fail if the server is not
running well.
This commit is contained in:
Zoltan Papp 2023-07-14 21:56:22 +02:00 committed by GitHub
parent 9c2c0e7934
commit 7ebe58f20a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 925 additions and 366 deletions

View File

@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
@ -35,6 +36,11 @@ type RouteListener interface {
routemanager.RouteListener
}
// DnsReadyListener export internal dns ReadyListener for mobile
type DnsReadyListener interface {
dns.ReadyListener
}
func init() {
formatter.SetLogcatFormatter(log.StandardLogger())
}
@ -49,6 +55,7 @@ type Client struct {
ctxCancelLock *sync.Mutex
deviceName string
routeListener routemanager.RouteListener
onHostDnsFn func([]string)
}
// NewClient instantiate a new Client
@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener) error {
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
c.onHostDnsFn = func([]string) {}
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
}
// Stop the internal client and free the resources
@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos}
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns()
if err != nil {
return err
}
dnsServer.OnUpdatedHostDNSServer(list.items)
return nil
}
// SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener)

View File

@ -0,0 +1,26 @@
package android
import "fmt"
// DNSList is a wrapper of []string
type DNSList struct {
items []string
}
// Add new DNS address to the collection
func (array *DNSList) Add(s string) {
array.items = append(array.items, s)
}
// Get return an element of the collection
func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range")
}
return array.items[i], nil
}
// Size return with the size of the collection
func (array *DNSList) Size() int {
return len(array.items)
}

View File

@ -0,0 +1,24 @@
package android
import "testing"
func TestDNSList_Get(t *testing.T) {
l := DNSList{
items: make([]string, 1),
}
_, err := l.Get(0)
if err != nil {
t.Errorf("invalid error: %s", err)
}
_, err = l.Get(-1)
if err == nil {
t.Errorf("expected error but got nil")
}
_, err = l.Get(1)
if err == nil {
t.Errorf("expected error but got nil")
}
}

View File

@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
@ -24,7 +25,24 @@ import (
)
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{})
}
// RunClientMobile with main logic on mobile system
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return wrapErr(err)
}
// in case of non Android os these variables will be nil
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
err = engine.Start()
if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)

View File

@ -1,13 +1,9 @@
package dns
import (
"github.com/netbirdio/netbird/iface"
)
type androidHostManager struct {
}
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
func newHostManager(wgInterface WGIface) (hostManager, error) {
return &androidHostManager{}, nil
}

View File

@ -9,8 +9,6 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
)
const (
@ -34,7 +32,7 @@ type systemConfigurator struct {
createdKeys map[string]struct{}
}
func newHostManager(_ *iface.WGIface) (hostManager, error) {
func newHostManager(_ WGIface) (hostManager, error) {
return &systemConfigurator{
createdKeys: make(map[string]struct{}),
}, nil

View File

@ -5,10 +5,10 @@ package dns
import (
"bufio"
"fmt"
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
const (
@ -25,7 +25,7 @@ const (
type osManagerType int
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err

View File

@ -6,8 +6,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/iface"
)
const (
@ -33,7 +31,7 @@ type registryConfigurator struct {
existingSearchDomains []string
}
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
func newHostManager(wgInterface WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil {
return nil, err

View File

@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
return ""
}
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
//TODO implement me
panic("implement me")
}
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
if m.UpdateDNSServerFunc != nil {

View File

@ -14,8 +14,6 @@ import (
"github.com/hashicorp/go-version"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
)
const (
@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
}
}
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, err

View File

@ -8,8 +8,6 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
)
const resolvconfCommand = "resolvconf"
@ -18,7 +16,7 @@ type resolvconf struct {
ifaceName string
}
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
return &resolvconf{
ifaceName: wgInterface.Name(),
}, nil

View File

@ -3,29 +3,20 @@ package dns
import (
"context"
"fmt"
"math/big"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
const (
defaultPort = 53
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
}
// Server is a dns server interface
type Server interface {
@ -33,6 +24,7 @@ type Server interface {
Stop()
DnsIP() string
UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string)
}
type registeredHandlerMap map[string]handlerWithStop
@ -42,21 +34,19 @@ type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
mux sync.Mutex
udpFilterHookID string
server *dns.Server
dnsMux *dns.ServeMux
service service
dnsMuxMap registeredHandlerMap
localResolver *localResolver
wgInterface *iface.WGIface
wgInterface WGIface
hostManager hostManager
updateSerial uint64
listenerIsRunning bool
runtimePort int
runtimeIP string
previousConfigHash uint64
currentConfig hostDNSConfig
customAddress *netip.AddrPort
enabled bool
// permanent related properties
permanent bool
hostsDnsList []string
hostsDnsListLock sync.Mutex
}
type handlerWithStop interface {
@ -70,9 +60,7 @@ type muxUpdate struct {
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
mux := dns.NewServeMux()
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@ -82,37 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
addrPort = &parsedAddrPort
}
ctx, stop := context.WithCancel(ctx)
var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = newServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone()
setServerDns(ds)
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
ctxCancel: stop,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
dnsMux: mux,
service: dnsService,
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
wgInterface: wgInterface,
customAddress: addrPort,
}
if initialDnsCfg != nil {
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
}
if wgInterface.IsUserspaceBind() {
defaultServer.evelRuntimeAddressForUserspace()
}
return defaultServer, nil
return defaultServer
}
// Initialize instantiate host manager. It required to be initialized wginterface
// Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -121,72 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
return nil
}
if !s.wgInterface.IsUserspaceBind() {
s.evalRuntimeAddress()
if s.permanent {
err = s.service.Listen()
if err != nil {
return err
}
}
s.hostManager, err = newHostManager(s.wgInterface)
return
}
// listen runs the listener in a go routine
func (s *DefaultServer) listen() {
// nil check required in unit tests
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.udpFilterHookID = s.filterDNSTraffic()
s.setListenerStatus(true)
return
}
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
}
// DnsIP returns the DNS resolver server IP address
//
// When kernel space interface used it return real DNS server listener IP address
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
func (s *DefaultServer) DnsIP() string {
if !s.enabled {
return ""
}
return s.runtimeIP
}
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" && s.wgInterface != nil {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *DefaultServer) setListenerStatus(running bool) {
s.listenerIsRunning = running
return s.service.RuntimeIP()
}
// Stop stops the server
@ -202,37 +148,23 @@ func (s *DefaultServer) Stop() {
}
}
err := s.stopListener()
if err != nil {
log.Error(err)
}
s.service.Stop()
}
func (s *DefaultServer) stopListener() error {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
// udpFilterHookID here empty only in the unit tests
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}
s.udpFilterHookID = ""
s.listenerIsRunning = false
return nil
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDnsListLock.Lock()
defer s.hostsDnsListLock.Unlock()
if !s.listenerIsRunning {
return nil
s.hostsDnsList = hostsDnsList
_, ok := s.dnsMuxMap[nbdns.RootZone]
if ok {
log.Debugf("on new host DNS config but skip to apply it")
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
}
return nil
log.Debugf("update host DNS settings: %+v", hostsDnsList)
s.addHostRootZone()
}
// UpdateDNSServer processes an update received from the management service
@ -283,12 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
if err := s.stopListener(); err != nil {
log.Error(err)
}
} else if !s.listenerIsRunning {
s.listen()
if update.ServiceEnable {
_ = s.service.Listen()
} else if !s.permanent {
s.service.Stop()
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
@ -299,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
hostUpdate := s.currentConfig
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
hostUpdate.routeAll = false
@ -412,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap)
var isContainRootUpdate bool
for _, update := range muxUpdates {
s.registerMux(update.domain, update.handler)
s.service.RegisterMux(update.domain, update.handler)
muxUpdateMap[update.domain] = update.handler
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
}
if update.domain == nbdns.RootZone {
isContainRootUpdate = true
}
}
for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop()
s.deregisterMux(key)
} else {
existingHandler.stop()
s.service.DeregisterMux(key)
}
}
}
@ -455,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
}
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *DefaultServer) deregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate.
@ -490,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
for i, item := range s.currentConfig.domains {
if _, found := removeIndex[item.domain]; found {
s.currentConfig.domains[i].disabled = true
s.deregisterMux(item.domain)
s.service.DeregisterMux(item.domain)
removeIndex[item.domain] = i
}
}
@ -507,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue
}
s.currentConfig.domains[i].disabled = false
s.registerMux(domain, handler)
s.service.RegisterMux(domain, handler)
}
l := log.WithField("nameservers", nsGroup.NameServers)
@ -523,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
return
}
func (s *DefaultServer) filterDNSTraffic() string {
filter := s.wgInterface.GetFilter()
if filter == nil {
log.Error("can't set DNS filter, filter not initialized")
return ""
func (s *DefaultServer) addHostRootZone() {
handler := newUpstreamResolver(s.ctx)
handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList {
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = defaultPort
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}
func (s *DefaultServer) evalRuntimeAddress() {
defer func() {
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
}()
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
return
}
ip, port, err := s.getFirstListenerAvailable()
if err != nil {
log.Error(err)
return
}
s.runtimeIP = ip
s.runtimePort = port
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}
func hasValidDnsServer(cfg *nbdns.Config) bool {
for _, c := range cfg.NameServerGroups {
if c.Primary {
return true
}
}
return false
handler.deactivate = func() {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
}

View File

@ -0,0 +1,29 @@
package dns
import (
"fmt"
"sync"
)
var (
mutex sync.Mutex
server Server
)
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
func GetServerDns() (Server, error) {
mutex.Lock()
if server == nil {
mutex.Unlock()
return nil, fmt.Errorf("DNS server not instantiated yet")
}
s := server
mutex.Unlock()
return s, nil
}
func setServerDns(newServerServer Server) {
mutex.Lock()
server = newServerServer
defer mutex.Unlock()
}

View File

@ -0,0 +1,24 @@
package dns
import (
"testing"
)
func TestGetServerDns(t *testing.T) {
_, err := GetServerDns()
if err == nil {
t.Errorf("invalid dns server instance")
}
srv := &MockServer{}
setServerDns(srv)
srvB, err := GetServerDns()
if err != nil {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
t.Errorf("missmatch dns instances")
}
}

View File

@ -11,14 +11,53 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
)
type mocWGIface struct {
filter iface.PacketFilter
}
func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() iface.WGAddress {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{
IP: ip,
Network: network,
}
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
panic("implement me")
}
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
panic("implement me")
}
func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
w.filter = filter
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@ -29,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
},
}
func init() {
log.SetLevel(log.TraceLevel)
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
@ -224,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Fatal(err)
}
@ -242,8 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.registeredMap = testCase.initLocalMap
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
@ -282,7 +324,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Errorf("create stdnet: %v", err)
@ -316,17 +358,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().SetNetwork(ipNet)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@ -421,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
dnsServer.hostManager = newNoopHostMocker()
dnsServer.listen()
time.Sleep(100 * time.Millisecond)
if !dnsServer.listenerIsRunning {
t.Fatal("dns server listener is not running")
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
if err != nil {
t.Fatalf("%v", err)
}
dnsServer.hostManager = newNoopHostMocker()
err = dnsServer.service.Listen()
if err != nil {
t.Fatalf("dns server is not running: %s", err)
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop()
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
resolver := &net.Resolver{
PreferGo: true,
@ -443,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
d := net.Dialer{
Timeout: time.Second * 5,
}
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
t.Log(err)
@ -478,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
dnsMux: dns.DefaultServeMux,
service: newServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
@ -541,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
}
}
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
mux := dns.NewServeMux()
var parsedAddrPort *netip.AddrPort
if addrPort != "" {
parsed, err := netip.ParseAddrPort(addrPort)
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal(err)
}
parsedAddrPort = &parsed
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
}
ctx, cancel := context.WithCancel(context.TODO())
ds := &DefaultServer{
ctx: ctx,
ctxCancel: cancel,
server: dnsServer,
dnsMux: mux,
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
customAddress: parsedAddrPort,
}
ds.evalRuntimeAddress()
return ds
}
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
var dnsList []string
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Enabled: true,
Primary: true,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
update2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{},
}
err = dnsServer.UpdateDNSServer(2, update2)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
}
func TestDNSPermanent_matchOnly(t *testing.T) {
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
}
defer wgIFace.Close()
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
return
}
defer dnsServer.Stop()
// check initial state
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
Domains: []string{"customdomain.com"},
Primary: false,
},
},
}
err = dnsServer.UpdateDNSServer(1, update)
if err != nil {
t.Errorf("failed to update dns server: %s", err)
}
_, err = resolver.LookupHost(context.Background(), "netbird.io")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
if err != nil {
t.Fatalf("failed resolve zone record: %v", err)
}
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(nil)
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, nil
}
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err
}
err = wgIface.Create()
if err != nil {
t.Fatalf("crate and init wireguard interface: %v", err)
return nil, err
}
pf, err := uspfilter.Create(wgIface)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err
}
err = wgIface.SetFilter(pf)
if err != nil {
t.Fatalf("set packet filter: %v", err)
return nil, err
}
return wgIface, nil
}
func newDnsResolver(ip string, port int) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Second * 3,
}
addr := fmt.Sprintf("%s:%d", ip, port)
return d.DialContext(ctx, network, addr)
},
}
}

View File

@ -0,0 +1,18 @@
package dns
import (
"github.com/miekg/dns"
)
const (
defaultPort = 53
)
type service interface {
Listen() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int
RuntimeIP() string
}

View File

@ -0,0 +1,145 @@
package dns
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
customPort = 5053
defaultIP = "127.0.0.1"
customIP = "127.0.0.153"
)
type serviceViaListener struct {
wgInterface WGIface
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
runtimeIP string
runtimePort int
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
}
return s
}
func (s *serviceViaListener) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
if err != nil {
log.Errorf("failed to eval runtime address: %s", err)
return err
}
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
}
}()
return nil
}
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaListener) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaListener) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaListener) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerIsRunning = running
}
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
ips := []string{defaultIP, customIP}
if runtime.GOOS != "darwin" {
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
}
ports := []int{defaultPort, customPort}
for _, port := range ports {
for _, ip := range ips {
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err == nil {
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
}
return ip, port, nil
}
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
}
}
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
}
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
if s.customAddr != nil {
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
}
return s.getFirstListenerAvailable()
}

View File

@ -0,0 +1,139 @@
package dns
import (
"fmt"
"math/big"
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
runtimePort int
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
runtimePort: defaultPort,
}
return s
}
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if s.listenerIsRunning {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return err
}
s.listenerIsRunning = true
log.Debugf("dns service listening on: %s", s.RuntimeIP())
return nil
}
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
}
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@ -0,0 +1,31 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@ -15,7 +15,6 @@ import (
"golang.org/x/sys/unix"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
const (
@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool
}
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface.Name())
if err != nil {
return nil, err

View File

@ -0,0 +1,14 @@
//go:build !windows
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
}

View File

@ -0,0 +1,13 @@
package dns
import "github.com/netbirdio/netbird/iface"
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper
GetInterfaceGUIDString() (string, error)
}

View File

@ -190,23 +190,25 @@ func (e *Engine) Start() error {
}
var routes []*route.Route
var dnsCfg *nbdns.Config
if runtime.GOOS == "android" {
routes, dnsCfg, err = e.readInitialSettings()
routes, err = e.readInitialSettings()
if err != nil {
return err
}
}
if e.dnsServer == nil {
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
go e.mobileDep.DnsReadyListener.OnReady()
}
} else {
// todo fix custom address
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
if e.dnsServer == nil {
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil {
e.close()
return err
}
e.dnsServer = dnsServer
}
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
@ -1045,14 +1047,13 @@ func (e *Engine) close() {
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
netMap, err := e.mgmClient.GetNetworkMap()
if err != nil {
return nil, nil, err
return nil, err
}
routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
return routes, &dnsCfg, nil
return routes, nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {

View File

@ -1,6 +1,7 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
@ -11,4 +12,6 @@ type MobileDependency struct {
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
RouteListener routemanager.RouteListener
HostDNSAddresses []string
DnsReadyListener dns.ReadyListener
}

View File

@ -102,7 +102,7 @@ func (s *Server) Start() error {
}
go func() {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil {
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
log.Errorf("init connections: %v", err)
}
}()
@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
go func() {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil {
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
log.Errorf("run client connection: %v", err)
return
}

View File

@ -5,7 +5,6 @@ import (
"sync"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
)
// NewWGIFace Creates a new WireGuard interface instance
@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create(mIFaceArgs)
}

View File

@ -7,7 +7,6 @@ import (
"sync"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
)
// NewWGIFace Creates a new WireGuard interface instance
@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
return w.tun.Create()
}

View File

@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
}
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
log.Info("create tun interface")
var err error
routesString := t.routesToString(mIFaceArgs.Routes)
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)

View File

@ -12,14 +12,14 @@ import (
func (c *tunDevice) Create() error {
if WireGuardModuleIsLoaded() {
log.Info("using kernel WireGuard")
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
return c.createWithKernel()
}
if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module")
}
log.Info("using userspace WireGuard")
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
var err error
c.netInterface, err = c.createWithUserspace()
if err != nil {