Feature/add iOS support (#1244)

* starting engine by passing file descriptor on engine start

* inject logger that does not compile

* logger and first client

* first working connection

* support for routes and working connection

* small refactor for better code quality in swift

* trying to add DNS

* fix

* updated

* fix route deletion

* trying to bind the DNS resolver dialer to an interface

* use dns.Client.Exchange

* fix metadata send on startup

* switching between client to query upstream

* fix panic on no dns response

* fix after merge changes

* add engine ready listener

* replace engine listener with connection listener

* disable relay connection for iOS until proxy is refactored into bind

* Extract private upstream for iOS and fix function headers for other OS

* Update mock Server

* Fix dns server and upstream tests

* Fix engine null pointer with mobile dependencies for other OS

* Revert back to disabling upstream on no response

* Fix some of the remarks from the linter

* Fix linter

* re-arrange duration calculation

* revert exported HostDNSConfig

* remove unused engine listener

* remove development logs

* refactor dns code and interface name propagation

* clean dns server test

* disable upstream deactivation for iOS

* remove files after merge

* fix dns server darwin

* fix server mock

* fix build flags

* move service listen back to initialize

* add wgInterface to hostManager initialization on android

* fix typo and remove unused function

* extract upstream exchange for ios and rest

* remove todo

* separate upstream logic to ios file

* Fix upstream test

* use interface and embedded struct for upstream

* set properly upstream client

* remove placeholder

* remove ios specific attributes

* fix upstream test

* merge ipc parser and wg configurer for mobile

* fix build annotation

* use json for DNS settings handover through gomobile

* add logs for DNS json string

* bring back check on ios for private upstream

* remove wrong (and unused) line

* fix wrongly updated comments on DNSSetting export

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
pascal-fischer 2023-12-18 11:46:58 +01:00 committed by GitHub
parent 01f28baec7
commit 818c6b885f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 1377 additions and 188 deletions

View File

@ -43,6 +43,15 @@ func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.S
return runClient(ctx, config, statusRecorder, mobileDependency) return runClient(ctx, config, statusRecorder, mobileDependency)
} }
func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager) error {
mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return runClient(ctx, config, statusRecorder, mobileDependency)
}
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error { func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
log.Infof("starting NetBird client version %s", version.NetbirdVersion()) log.Infof("starting NetBird client version %s", version.NetbirdVersion())

View File

@ -35,14 +35,14 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false return false
} }
func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error { func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
backupFileExist := false backupFileExist := false
_, err := os.Stat(fileDefaultResolvConfBackupLocation) _, err := os.Stat(fileDefaultResolvConfBackupLocation)
if err == nil { if err == nil {
backupFileExist = true backupFileExist = true
} }
if !config.routeAll { if !config.RouteAll {
if backupFileExist { if backupFileExist {
err = f.restore() err = f.restore()
if err != nil { if err != nil {
@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
buf := prepareResolvConfContent( buf := prepareResolvConfContent(
searchDomainList, searchDomainList,
append([]string{config.serverIP}, nameServers...), append([]string{config.ServerIP}, nameServers...),
others) others)
log.Debugf("creating managed file %s", defaultResolvConfPath) log.Debugf("creating managed file %s", defaultResolvConfPath)
@ -138,14 +138,14 @@ func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes
return buf return buf
} }
func searchDomains(config hostDNSConfig) []string { func searchDomains(config HostDNSConfig) []string {
listOfDomains := make([]string, 0) listOfDomains := make([]string, 0)
for _, dConf := range config.domains { for _, dConf := range config.Domains {
if dConf.matchOnly || dConf.disabled { if dConf.MatchOnly || dConf.Disabled {
continue continue
} }
listOfDomains = append(listOfDomains, dConf.domain) listOfDomains = append(listOfDomains, dConf.Domain)
} }
return listOfDomains return listOfDomains
} }
@ -214,7 +214,7 @@ func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, othe
return return
} }
// merge search domains lists and cut off the list if it is too long // merge search Domains lists and cut off the list if it is too long
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string { func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
lineSize := len("search") lineSize := len("search")
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains)) searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))
@ -225,14 +225,14 @@ func mergeSearchDomains(searchDomains []string, originalSearchDomains []string)
return searchDomainsList return searchDomainsList
} }
// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long // validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
// extend s slice with vs elements // extend s slice with vs elements
// return with the number of characters in the searchDomains line // return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int { func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
for _, sd := range vs { for _, sd := range vs {
tmpCharsNumber := initialLineChars + 1 + len(sd) tmpCharsNumber := initialLineChars + 1 + len(sd)
if tmpCharsNumber > fileMaxLineCharsLimit { if tmpCharsNumber > fileMaxLineCharsLimit {
// lets log all skipped domains // lets log all skipped Domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd) log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
continue continue
} }
@ -240,7 +240,7 @@ func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string
initialLineChars = tmpCharsNumber initialLineChars = tmpCharsNumber
if len(*s) >= fileMaxNumberOfSearchDomains { if len(*s) >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains // lets log all skipped Domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd) log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
continue continue
} }

View File

@ -8,31 +8,31 @@ import (
) )
type hostManager interface { type hostManager interface {
applyDNSConfig(config hostDNSConfig) error applyDNSConfig(config HostDNSConfig) error
restoreHostDNS() error restoreHostDNS() error
supportCustomPort() bool supportCustomPort() bool
} }
type hostDNSConfig struct { type HostDNSConfig struct {
domains []domainConfig Domains []DomainConfig `json:"domains"`
routeAll bool RouteAll bool `json:"routeAll"`
serverIP string ServerIP string `json:"serverIP"`
serverPort int ServerPort int `json:"serverPort"`
} }
type domainConfig struct { type DomainConfig struct {
disabled bool Disabled bool `json:"disabled"`
domain string Domain string `json:"domain"`
matchOnly bool MatchOnly bool `json:"matchOnly"`
} }
type mockHostConfigurator struct { type mockHostConfigurator struct {
applyDNSConfigFunc func(config hostDNSConfig) error applyDNSConfigFunc func(config HostDNSConfig) error
restoreHostDNSFunc func() error restoreHostDNSFunc func() error
supportCustomPortFunc func() bool supportCustomPortFunc func() bool
} }
func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error { func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
if m.applyDNSConfigFunc != nil { if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config) return m.applyDNSConfigFunc(config)
} }
@ -55,38 +55,38 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
func newNoopHostMocker() hostManager { func newNoopHostMocker() hostManager {
return &mockHostConfigurator{ return &mockHostConfigurator{
applyDNSConfigFunc: func(config hostDNSConfig) error { return nil }, applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
restoreHostDNSFunc: func() error { return nil }, restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true }, supportCustomPortFunc: func() bool { return true },
} }
} }
func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig { func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig {
config := hostDNSConfig{ config := HostDNSConfig{
routeAll: false, RouteAll: false,
serverIP: ip, ServerIP: ip,
serverPort: port, ServerPort: port,
} }
for _, nsConfig := range dnsConfig.NameServerGroups { for _, nsConfig := range dnsConfig.NameServerGroups {
if len(nsConfig.NameServers) == 0 { if len(nsConfig.NameServers) == 0 {
continue continue
} }
if nsConfig.Primary { if nsConfig.Primary {
config.routeAll = true config.RouteAll = true
} }
for _, domain := range nsConfig.Domains { for _, domain := range nsConfig.Domains {
config.domains = append(config.domains, domainConfig{ config.Domains = append(config.Domains, DomainConfig{
domain: strings.TrimSuffix(domain, "."), Domain: strings.TrimSuffix(domain, "."),
matchOnly: !nsConfig.SearchDomainsEnabled, MatchOnly: !nsConfig.SearchDomainsEnabled,
}) })
} }
} }
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
config.domains = append(config.domains, domainConfig{ config.Domains = append(config.Domains, DomainConfig{
domain: strings.TrimSuffix(customZone.Domain, "."), Domain: strings.TrimSuffix(customZone.Domain, "."),
matchOnly: false, MatchOnly: false,
}) })
} }

View File

@ -7,7 +7,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }
func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error { func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
return nil return nil
} }

View File

@ -1,3 +1,5 @@
//go:build !ios
package dns package dns
import ( import (
@ -42,11 +44,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if config.routeAll { if config.RouteAll {
err = s.addDNSSetupForAll(config.serverIP, config.serverPort) err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil { if err != nil {
return err return err
} }
@ -56,7 +58,7 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
return err return err
} }
s.primaryServiceID = "" s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort) log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
} }
var ( var (
@ -64,20 +66,20 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
matchDomains []string matchDomains []string
) )
for _, dConf := range config.domains { for _, dConf := range config.Domains {
if dConf.disabled { if dConf.Disabled {
continue continue
} }
if dConf.matchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.domain) matchDomains = append(matchDomains, dConf.Domain)
continue continue
} }
searchDomains = append(searchDomains, dConf.domain) searchDomains = append(searchDomains, dConf.Domain)
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort) err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else { } else {
log.Infof("removing match domains from the system") log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey) err = s.removeKeyFromSystemConfig(matchKey)
@ -88,7 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error {
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 { if len(searchDomains) != 0 {
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort) err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
} else { } else {
log.Infof("removing search domains from the system") log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey) err = s.removeKeyFromSystemConfig(searchKey)

View File

@ -0,0 +1,37 @@
package dns
import (
"encoding/json"
log "github.com/sirupsen/logrus"
)
type iosHostManager struct {
dnsManager IosDnsManager
config HostDNSConfig
}
func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
return &iosHostManager{
dnsManager: dnsManager,
}, nil
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
jsonData, err := json.Marshal(config)
if err != nil {
return err
}
jsonString := string(jsonData)
log.Debugf("Applying DNS settings: %s", jsonString)
a.dnsManager.ApplyDns(jsonString)
return nil
}
func (a iosHostManager) restoreHostDNS() error {
return nil
}
func (a iosHostManager) supportCustomPort() bool {
return false
}

View File

@ -43,10 +43,10 @@ func (s *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if config.routeAll { if config.RouteAll {
err = r.addDNSSetupForAll(config.serverIP) err = r.addDNSSetupForAll(config.ServerIP)
if err != nil { if err != nil {
return err return err
} }
@ -56,7 +56,7 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
return err return err
} }
r.routingAll = false r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
var ( var (
@ -64,18 +64,18 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error {
matchDomains []string matchDomains []string
) )
for _, dConf := range config.domains { for _, dConf := range config.Domains {
if dConf.disabled { if dConf.Disabled {
continue continue
} }
if !dConf.matchOnly { if !dConf.MatchOnly {
searchDomains = append(searchDomains, dConf.domain) searchDomains = append(searchDomains, dConf.Domain)
} }
matchDomains = append(matchDomains, "."+dConf.domain) matchDomains = append(matchDomains, "."+dConf.Domain)
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
err = r.addDNSMatchPolicy(matchDomains, config.serverIP) err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
} else { } else {
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
} }

View File

@ -1,10 +1,12 @@
package dns package dns
import ( import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"strings" "strings"
"testing" "testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
) )
func TestLocalResolver_ServeDNS(t *testing.T) { func TestLocalResolver_ServeDNS(t *testing.T) {

View File

@ -33,7 +33,7 @@ func (m *MockServer) DnsIP() string {
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
//TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@ -12,8 +12,9 @@ import (
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns" "github.com/miekg/dns"
nbversion "github.com/netbirdio/netbird/version"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbversion "github.com/netbirdio/netbird/version"
) )
const ( const (
@ -93,7 +94,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
return false return false
} }
func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings() connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil { if err != nil {
return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err) return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err)
@ -101,7 +102,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
connSettings.cleanDeprecatedSettings() connSettings.cleanDeprecatedSettings()
dnsIP, err := netip.ParseAddr(config.serverIP) dnsIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %s", err) return fmt.Errorf("unable to parse ip address, error: %s", err)
} }
@ -111,33 +112,33 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string
) )
for _, dConf := range config.domains { for _, dConf := range config.Domains {
if dConf.disabled { if dConf.Disabled {
continue continue
} }
if dConf.matchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain)) matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
continue continue
} }
searchDomains = append(searchDomains, dns.Fqdn(dConf.domain)) searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
} }
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
priority := networkManagerDbusSearchDomainOnlyPriority priority := networkManagerDbusSearchDomainOnlyPriority
switch { switch {
case config.routeAll: case config.RouteAll:
priority = networkManagerDbusPrimaryDNSPriority priority = networkManagerDbusPrimaryDNSPriority
newDomainList = append(newDomainList, "~.") newDomainList = append(newDomainList, "~.")
if !n.routingAll { if !n.routingAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} }
case len(matchDomains) > 0: case len(matchDomains) > 0:
priority = networkManagerDbusWithMatchDomainPriority priority = networkManagerDbusWithMatchDomainPriority
} }
if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll { if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
n.routingAll = false n.routingAll = false
} }

View File

@ -52,6 +52,6 @@ func (n *notifier) notify() {
} }
go func(l listener.NetworkChangeListener) { go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged() l.OnNetworkChanged("")
}(n.listener) }(n.listener)
} }

View File

@ -39,9 +39,9 @@ func (r *resolvconf) supportCustomPort() bool {
return false return false
} }
func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error { func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if !config.routeAll { if !config.RouteAll {
err = r.restoreHostDNS() err = r.restoreHostDNS()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -54,7 +54,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error {
buf := prepareResolvConfContent( buf := prepareResolvConfContent(
searchDomainList, searchDomainList,
append([]string{config.serverIP}, r.originalNameServers...), append([]string{config.ServerIP}, r.originalNameServers...),
r.othersConfigs) r.othersConfigs)
err = r.applyConfig(buf) err = r.applyConfig(buf)

View File

@ -19,6 +19,11 @@ type ReadyListener interface {
OnReady() OnReady()
} }
// IosDnsManager is a dns manager interface for iOS
type IosDnsManager interface {
ApplyDns(string)
}
// Server is a dns server interface // Server is a dns server interface
type Server interface { type Server interface {
Initialize() error Initialize() error
@ -43,7 +48,7 @@ type DefaultServer struct {
hostManager hostManager hostManager hostManager
updateSerial uint64 updateSerial uint64
previousConfigHash uint64 previousConfigHash uint64
currentConfig hostDNSConfig currentConfig HostDNSConfig
// permanent related properties // permanent related properties
permanent bool permanent bool
@ -52,6 +57,7 @@ type DefaultServer struct {
// make sense on mobile only // make sense on mobile only
searchDomainNotifier *notifier searchDomainNotifier *notifier
iosDnsManager IosDnsManager
} }
type handlerWithStop interface { type handlerWithStop interface {
@ -99,6 +105,13 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface,
return ds return ds
} }
// NewDefaultServerIos returns a new dns server. It optimized for ios
func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer { func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
@ -131,8 +144,8 @@ func (s *DefaultServer) Initialize() (err error) {
} }
} }
s.hostManager, err = newHostManager(s.wgInterface) s.hostManager, err = s.initialize()
return return err
} }
// DnsIP returns the DNS resolver server IP address // DnsIP returns the DNS resolver server IP address
@ -223,20 +236,20 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
func (s *DefaultServer) SearchDomains() []string { func (s *DefaultServer) SearchDomains() []string {
var searchDomains []string var searchDomains []string
for _, dConf := range s.currentConfig.domains { for _, dConf := range s.currentConfig.Domains {
if dConf.disabled { if dConf.Disabled {
continue continue
} }
if dConf.matchOnly { if dConf.MatchOnly {
continue continue
} }
searchDomains = append(searchDomains, dConf.domain) searchDomains = append(searchDomains, dConf.Domain)
} }
return searchDomains return searchDomains
} }
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener or fake resolver // is the service should be Disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records // and proceed with a regular update to clean up the handlers and records
if update.ServiceEnable { if update.ServiceEnable {
_ = s.service.Listen() _ = s.service.Listen()
@ -262,7 +275,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.routeAll = false hostUpdate.RouteAll = false
} }
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
@ -312,7 +325,10 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
continue continue
} }
handler := newUpstreamResolver(s.ctx) handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
}
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType { if ns.NSType != nbdns.UDPNameServerType {
log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", log.Warnf("skipping nameserver %s with type %s, this peer supports only %s",
@ -445,14 +461,14 @@ func (s *DefaultServer) upstreamCallbacks(
} }
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.routeAll = false s.currentConfig.RouteAll = false
} }
for i, item := range s.currentConfig.domains { for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.domain]; found { if _, found := removeIndex[item.Domain]; found {
s.currentConfig.domains[i].disabled = true s.currentConfig.Domains[i].Disabled = true
s.service.DeregisterMux(item.domain) s.service.DeregisterMux(item.Domain)
removeIndex[item.domain] = i removeIndex[item.Domain] = i
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
@ -464,28 +480,32 @@ func (s *DefaultServer) upstreamCallbacks(
defer s.mux.Unlock() defer s.mux.Unlock()
for domain, i := range removeIndex { for domain, i := range removeIndex {
if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain {
continue continue
} }
s.currentConfig.domains[i].disabled = false s.currentConfig.Domains[i].Disabled = false
s.service.RegisterMux(domain, handler) s.service.RegisterMux(domain, handler)
} }
l := log.WithField("nameservers", nsGroup.NameServers) l := log.WithField("nameservers", nsGroup.NameServers)
l.Debug("reactivate temporary disabled nameserver group") l.Debug("reactivate temporary Disabled nameserver group")
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.routeAll = true s.currentConfig.RouteAll = true
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary Disabled nameserver group, DNS update apply")
} }
} }
return return
} }
func (s *DefaultServer) addHostRootZone() { func (s *DefaultServer) addHostRootZone() {
handler := newUpstreamResolver(s.ctx) handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network)
if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler.upstreamServers = make([]string, len(s.hostsDnsList)) handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList { for n, ua := range s.hostsDnsList {
a, err := netip.ParseAddr(ua) a, err := netip.ParseAddr(ua)

View File

@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@ -0,0 +1,7 @@
//go:build !ios
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.iosDnsManager)
}

View File

@ -0,0 +1,7 @@
//go:build !android
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@ -527,8 +527,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
}, },
hostManager: hostManager, hostManager: hostManager,
currentConfig: hostDNSConfig{ currentConfig: HostDNSConfig{
domains: []domainConfig{ Domains: []DomainConfig{
{false, "domain0", false}, {false, "domain0", false},
{false, "domain1", false}, {false, "domain1", false},
{false, "domain2", false}, {false, "domain2", false},
@ -537,13 +537,13 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
var domainsUpdate string var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error { hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
domains := []string{} domains := []string{}
for _, item := range config.domains { for _, item := range config.Domains {
if item.disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.domain) domains = append(domains, item.Domain)
} }
domainsUpdate = strings.Join(domains, ",") domainsUpdate = strings.Join(domains, ",")
return nil return nil
@ -559,11 +559,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
deactivate() deactivate()
expected := "domain0,domain2" expected := "domain0,domain2"
domains := []string{} domains := []string{}
for _, item := range server.currentConfig.domains { for _, item := range server.currentConfig.Domains {
if item.disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.domain) domains = append(domains, item.Domain)
} }
got := strings.Join(domains, ",") got := strings.Join(domains, ",")
if expected != got { if expected != got {
@ -573,11 +573,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
reactivate() reactivate()
expected = "domain0,domain1,domain2" expected = "domain0,domain1,domain2"
domains = []string{} domains = []string{}
for _, item := range server.currentConfig.domains { for _, item := range server.currentConfig.Domains {
if item.disabled { if item.Disabled {
continue continue
} }
domains = append(domains, item.domain) domains = append(domains, item.Domain)
} }
got = strings.Join(domains, ",") got = strings.Join(domains, ",")
if expected != got { if expected != got {

View File

@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface)
}

View File

@ -81,8 +81,8 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
parsedIP, err := netip.ParseAddr(config.serverIP) parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %s", err) return fmt.Errorf("unable to parse ip address, error: %s", err)
} }
@ -93,7 +93,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
} }
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
if err != nil { if err != nil {
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err) return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.ServerIP, config.ServerPort, err)
} }
var ( var (
@ -101,24 +101,24 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
matchDomains []string matchDomains []string
domainsInput []systemdDbusLinkDomainsInput domainsInput []systemdDbusLinkDomainsInput
) )
for _, dConf := range config.domains { for _, dConf := range config.Domains {
if dConf.disabled { if dConf.Disabled {
continue continue
} }
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.domain), Domain: dns.Fqdn(dConf.Domain),
MatchOnly: dConf.matchOnly, MatchOnly: dConf.MatchOnly,
}) })
if dConf.matchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.domain) matchDomains = append(matchDomains, dConf.Domain)
continue continue
} }
searchDomains = append(searchDomains, dConf.domain) searchDomains = append(searchDomains, dConf.Domain)
} }
if config.routeAll { if config.RouteAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
if err != nil { if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %s", err) return fmt.Errorf("setting link as default dns router, failed with error: %s", err)
@ -129,7 +129,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error {
}) })
s.routingAll = true s.routingAll = true
} else if s.routingAll { } else if s.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -21,10 +22,15 @@ const (
) )
type upstreamClient interface { type upstreamClient interface {
ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
} }
type upstreamResolver struct { type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
@ -40,25 +46,25 @@ type upstreamResolver struct {
reactivate func() reactivate func()
} }
func newUpstreamResolver(parentCTX context.Context) *upstreamResolver { func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase {
ctx, cancel := context.WithCancel(parentCTX) ctx, cancel := context.WithCancel(parentCTX)
return &upstreamResolver{
return &upstreamResolverBase{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
upstreamClient: &dns.Client{},
upstreamTimeout: upstreamTimeout, upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
} }
} }
func (u *upstreamResolver) stop() { func (u *upstreamResolverBase) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel() u.cancel()
} }
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails() defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Trace("received an upstream question") log.WithField("question", r.Question[0]).Trace("received an upstream question")
@ -70,10 +76,8 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel() rm, t, err := u.upstreamClient.exchange(upstream, r)
if err != nil { if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) { if err == context.DeadlineExceeded || isTimeout(err) {
@ -83,7 +87,19 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream). log.WithError(err).WithField("upstream", upstream).
Error("got an error while querying the upstream") Error("got other error while querying the upstream")
return
}
if rm == nil {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return
}
// those checks need to be independent of each other due to memory address issues
if !rm.Response {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return return
} }
@ -106,7 +122,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If fails count is greater that failsTillDeact, upstream resolving // If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter // will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated. // will be reset and upstream will be reactivated.
func (u *upstreamResolver) checkUpstreamFails() { func (u *upstreamResolverBase) checkUpstreamFails() {
u.mutex.Lock() u.mutex.Lock()
defer u.mutex.Unlock() defer u.mutex.Unlock()
@ -118,15 +134,18 @@ func (u *upstreamResolver) checkUpstreamFails() {
case <-u.ctx.Done(): case <-u.ctx.Done():
return return
default: default:
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) // todo test the deactivation logic, it seems to affect the client
u.deactivate() if runtime.GOOS != "ios" {
u.disabled = true log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod)
go u.waitUntilResponse() u.deactivate()
u.disabled = true
go u.waitUntilResponse()
}
} }
} }
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response // waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolver) waitUntilResponse() { func (u *upstreamResolverBase) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{ exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond, InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5, RandomizationFactor: 0.5,
@ -148,10 +167,7 @@ func (u *upstreamResolver) waitUntilResponse() {
var err error var err error
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) _, _, err = u.upstreamClient.exchange(upstream, r)
_, _, err = u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err == nil { if err == nil {
return nil return nil

View File

@ -0,0 +1,93 @@
//go:build ios
package dns
import (
"context"
"net"
"syscall"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
type upstreamResolverIOS struct {
*upstreamResolverBase
lIP net.IP
lNet *net.IPNet
iIndex int
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
iIndex: index,
}
ios.upstreamClient = ios
return ios, nil
}
func (u *upstreamResolverIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
log.Errorf("error while parsing upstream host: %s", err)
}
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
client = u.getClientPrivate()
}
return client.Exchange(r, upstream)
}
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolverIOS) getClientPrivate() *dns.Client {
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: u.lIP,
Port: 0, // Let the OS pick a free port
},
Timeout: upstreamTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
}
if err := c.Control(fn); err != nil {
return err
}
if operr != nil {
log.Errorf("error while setting socket option: %s", operr)
}
return operr
},
}
client := &dns.Client{
Dialer: dialer,
}
return client
}
func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName)
return iface.Index, err
}

View File

@ -0,0 +1,32 @@
//go:build !ios
package dns
import (
"context"
"net"
"time"
"github.com/miekg/dns"
)
type upstreamResolverNonIOS struct {
*upstreamResolverBase
}
func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(parentCTX)
nonIOS := &upstreamResolverNonIOS{
upstreamResolverBase: upstreamResolverBase,
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolverNonIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
cancel()
return rm, t, err
}

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"context" "context"
"net"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -49,15 +50,6 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
timeout: upstreamTimeout, timeout: upstreamTimeout,
responseShouldBeNil: true, responseShouldBeNil: true,
}, },
//{
// name: "Should Resolve CNAME Record",
// inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME),
//},
//{
// name: "Should Not Write When Not Found A Record",
// inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
// responseShouldBeNil: true,
//},
} }
// should resolve if first upstream times out // should resolve if first upstream times out
// should not write when both fails // should not write when both fails
@ -66,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver := newUpstreamResolver(ctx) resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{})
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
@ -114,12 +106,12 @@ type mockUpstreamResolver struct {
} }
// ExchangeContext mock implementation of ExchangeContext from upstreamResolver // ExchangeContext mock implementation of ExchangeContext from upstreamResolver
func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) { func (c mockUpstreamResolver) exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err return c.r, c.rtt, c.err
} }
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
resolver := &upstreamResolver{ resolver := &upstreamResolverBase{
ctx: context.TODO(), ctx: context.TODO(),
upstreamClient: &mockUpstreamResolver{ upstreamClient: &mockUpstreamResolver{
err: nil, err: nil,
@ -156,7 +148,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
} }
if !resolver.disabled { if !resolver.disabled {
t.Errorf("resolver should be disabled") t.Errorf("resolver should be Disabled")
return return
} }

View File

@ -197,7 +197,8 @@ func (e *Engine) Start() error {
var routes []*route.Route var routes []*route.Route
if runtime.GOOS == "android" { switch runtime.GOOS {
case "android":
var dnsConfig *nbdns.Config var dnsConfig *nbdns.Config
routes, dnsConfig, err = e.readInitialSettings() routes, dnsConfig, err = e.readInitialSettings()
if err != nil { if err != nil {
@ -207,25 +208,34 @@ func (e *Engine) Start() error {
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener) e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener)
go e.mobileDep.DnsReadyListener.OnReady() go e.mobileDep.DnsReadyListener.OnReady()
} }
} else if e.dnsServer == nil { case "ios":
// todo fix custom address if e.dnsServer == nil {
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) e.dnsServer = dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager)
if err != nil { }
e.close() default:
return err if e.dnsServer == nil {
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
if err != nil {
e.close()
return err
}
} }
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
if runtime.GOOS == "android" { switch runtime.GOOS {
err = e.wgInterface.CreateOnMobile(iface.MobileIFaceArguments{ case "android":
err = e.wgInterface.CreateOnAndroid(iface.MobileIFaceArguments{
Routes: e.routeManager.InitialRouteRange(), Routes: e.routeManager.InitialRouteRange(),
Dns: e.dnsServer.DnsIP(), Dns: e.dnsServer.DnsIP(),
SearchDomains: e.dnsServer.SearchDomains(), SearchDomains: e.dnsServer.SearchDomains(),
}) })
} else { case "ios":
e.mobileDep.NetworkChangeListener.SetInterfaceIP(wgAddr)
err = e.wgInterface.CreateOniOS(e.mobileDep.FileDescriptor)
default:
err = e.wgInterface.Create() err = e.wgInterface.Create()
} }
if err != nil { if err != nil {
@ -480,7 +490,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
// start SSH server if it wasn't running // start SSH server if it wasn't running
if isNil(e.sshServer) { if isNil(e.sshServer) {
//nil sshServer means it has not yet been started // nil sshServer means it has not yet been started
var err error var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)) fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))

View File

@ -3,5 +3,6 @@ package listener
// NetworkChangeListener is a callback interface for mobile system // NetworkChangeListener is a callback interface for mobile system
type NetworkChangeListener interface { type NetworkChangeListener interface {
// OnNetworkChanged invoke when network settings has been changed // OnNetworkChanged invoke when network settings has been changed
OnNetworkChanged() OnNetworkChanged(string)
SetInterfaceIP(string)
} }

View File

@ -14,4 +14,6 @@ type MobileDependency struct {
NetworkChangeListener listener.NetworkChangeListener NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string HostDNSAddresses []string
DnsReadyListener dns.ReadyListener DnsReadyListener dns.ReadyListener
DnsManager dns.IosDnsManager
FileDescriptor int32
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -225,6 +226,10 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() { if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay} return []ice.CandidateType{ice.CandidateTypeRelay}
} }
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
} }
@ -464,7 +469,7 @@ func (conn *Conn) cleanup() error {
err := conn.statusRecorder.UpdatePeerState(peerState) err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
//todo rethink status updates // todo rethink status updates
log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err) log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err)
} }

View File

@ -2,6 +2,7 @@ package routemanager
import ( import (
"sort" "sort"
"strings"
"sync" "sync"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@ -50,9 +51,6 @@ func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
n.routeRangers = newNets n.routeRangers = newNets
if !n.hasDiff(n.initialRouteRangers, newNets) {
return
}
n.notify() n.notify()
} }
@ -64,7 +62,7 @@ func (n *notifier) notify() {
} }
go func(l listener.NetworkChangeListener) { go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged() l.OnNetworkChanged(strings.Join(n.routeRangers, ","))
}(n.listener) }(n.listener)
} }

View File

@ -1,3 +1,5 @@
//go:build android
package routemanager package routemanager
import ( import (

View File

@ -0,0 +1,15 @@
//go:build ios
package routemanager
import (
"net/netip"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@ -1,4 +1,4 @@
//go:build !android //go:build !android && !ios
package routemanager package routemanager

View File

@ -0,0 +1,224 @@
package NetBirdSDK
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
)
// ConnectionListener export internal Listener for mobile
type ConnectionListener interface {
peer.Listener
}
// RouteListener export internal RouteListener for mobile
type NetworkChangeListener interface {
listener.NetworkChangeListener
}
// DnsManager export internal dns Manager for mobile
type DnsManager interface {
dns.IosDnsManager
}
// CustomLogger export internal CustomLogger for mobile
type CustomLogger interface {
Debug(message string)
Info(message string)
Error(message string)
}
func init() {
formatter.SetLogcatFormatter(log.StandardLogger())
}
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
osName string
osVersion string
networkChangeListener listener.NetworkChangeListener
onHostDnsFn func([]string)
dnsManager dns.IosDnsManager
loginComplete bool
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{
cfgFile: cfgFile,
deviceName: deviceName,
osName: osName,
osVersion: osVersion,
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
dnsManager: dnsManager,
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string) error {
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
return err
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.OsNameCtxKey, c.osName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.OsVersionCtxKey, c.osVersion)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.Login()
if err != nil {
return err
}
log.Infof("Auth successful")
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.networkChangeListener, c.dnsManager)
}
// Stop the internal client and free the resources
func (c *Client) Stop() {
c.ctxCancelLock.Lock()
defer c.ctxCancelLock.Unlock()
if c.ctxCancel == nil {
return
}
c.ctxCancel()
}
// ÏSetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
}
// getStatusDetails return with the list of the PeerInfos
func (c *Client) GetStatusDetails() *StatusDetails {
fullStatus := c.recorder.GetFullStatus()
peerInfos := make([]PeerInfo, len(fullStatus.Peers))
for n, p := range fullStatus.Peers {
pi := PeerInfo{
p.IP,
p.FQDN,
p.ConnStatus.String(),
}
peerInfos[n] = pi
}
return &StatusDetails{items: peerInfos, fqdn: fullStatus.LocalPeerState.FQDN, ip: fullStatus.LocalPeerState.IP}
}
// SetConnectionListener set the network connection listener
func (c *Client) SetConnectionListener(listener ConnectionListener) {
c.recorder.SetConnectionListener(listener)
}
// RemoveConnectionListener remove connection listener
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func (c *Client) IsLoginRequired() bool {
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.OsNameCtxKey, c.osName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.OsVersionCtxKey, c.osVersion)
c.ctxCancelLock.Lock()
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
cfg, _ := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey)
return needsLogin
}
func (c *Client) LoginForMobile() string {
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.OsNameCtxKey, c.osName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.OsVersionCtxKey, c.osVersion)
c.ctxCancelLock.Lock()
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
cfg, _ := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile,
})
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
if err != nil {
return err.Error()
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
if err != nil {
return err.Error()
}
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return
}
jwtToken := tokenInfo.GetTokenToUse()
_ = internal.Login(ctx, cfg, "", jwtToken)
c.loginComplete = true
}()
return flowInfo.VerificationURIComplete
}
func (c *Client) IsLoginComplete() bool {
return c.loginComplete
}
func (c *Client) ClearLoginComplete() {
c.loginComplete = false
}

View File

@ -0,0 +1,5 @@
package NetBirdSDK
import _ "golang.org/x/mobile/bind"
// to keep our CI/CD that checks go.mod and go.sum files happy, we need to import the package above

View File

@ -0,0 +1,10 @@
package NetBirdSDK
import (
"github.com/netbirdio/netbird/util"
)
// InitializeLog initializes the log file.
func InitializeLog(logLevel string, filePath string) error {
return util.InitLog(logLevel, filePath)
}

View File

@ -0,0 +1,159 @@
package NetBirdSDK
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/system"
)
// SSOListener is async listener for mobile framework
type SSOListener interface {
OnSuccess(bool)
OnError(error)
}
// ErrListener is async listener for mobile framework
type ErrListener interface {
OnSuccess()
OnError(error)
}
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
}
// Auth can register or login new client
type Auth struct {
ctx context.Context
config *internal.Config
cfgPath string
}
// NewAuth instantiate Auth struct and validate the management URL
func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
inputCfg := internal.ConfigInput{
ManagementURL: mgmURL,
}
cfg, err := internal.CreateInMemoryConfig(inputCfg)
if err != nil {
return nil, err
}
return &Auth{
ctx: context.Background(),
config: cfg,
cfgPath: cfgPath,
}, nil
}
// NewAuthWithConfig instantiate Auth based on existing config
func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
return &Auth{
ctx: ctx,
config: config,
}
}
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
supportsSSO = false
err = nil
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = internal.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return internal.WriteOutConfig(a.cfgPath, a.config)
}
func (a *Auth) Login() error {
var needsLogin bool
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
return
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
return fmt.Errorf("Not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@ -0,0 +1,50 @@
package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
FQDN string
ConnStatus string // Todo replace to enum
}
// PeerInfoCollection made for Java layer to get non default types as collection
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
GetFQDN() string
GetIP() string
}
// StatusDetails is the implementation of the PeerInfoCollection
type StatusDetails struct {
items []PeerInfo
fqdn string
ip string
}
// Add new PeerInfo to the collection
func (array StatusDetails) Add(s PeerInfo) StatusDetails {
array.items = append(array.items, s)
return array
}
// Get return an element of the collection
func (array StatusDetails) Get(i int) *PeerInfo {
return &array.items[i]
}
// Size return with the size of the collection
func (array StatusDetails) Size() int {
return len(array.items)
}
// GetFQDN return with the FQDN of the local peer
func (array StatusDetails) GetFQDN() string {
return array.fqdn
}
// GetIP return with the IP of the local peer
func (array StatusDetails) GetIP() string {
return array.ip
}

View File

@ -0,0 +1,78 @@
package NetBirdSDK
import (
"github.com/netbirdio/netbird/client/internal"
)
// Preferences export a subset of the internal config for gomobile
type Preferences struct {
configInput internal.ConfigInput
}
// NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{
ConfigPath: configPath,
}
return &Preferences{ci}
}
// GetManagementURL read url from config file
func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
return cfg.ManagementURL.String(), err
}
// SetManagementURL store the given url and wait for commit
func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url
}
// GetAdminURL read url from config file
func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
return cfg.AdminURL.String(), err
}
// SetAdminURL store the given url and wait for commit
func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url
}
// GetPreSharedKey read preshared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return "", err
}
return cfg.PreSharedKey, err
}
// SetPreSharedKey store the given key and wait for commit
func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key
}
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput)
return err
}

View File

@ -0,0 +1,120 @@
package NetBirdSDK
import (
"path/filepath"
"testing"
"github.com/netbirdio/netbird/client/internal"
)
func TestPreferences_DefaultValues(t *testing.T) {
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile)
defaultVar, err := p.GetAdminURL()
if err != nil {
t.Fatalf("failed to read default value: %s", err)
}
if defaultVar != internal.DefaultAdminURL {
t.Errorf("invalid default admin url: %s", defaultVar)
}
defaultVar, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read default management URL: %s", err)
}
if defaultVar != internal.DefaultManagementURL {
t.Errorf("invalid default management url: %s", defaultVar)
}
var preSharedKey string
preSharedKey, err = p.GetPreSharedKey()
if err != nil {
t.Fatalf("failed to read default preshared key: %s", err)
}
if preSharedKey != "" {
t.Errorf("invalid preshared key: %s", preSharedKey)
}
}
func TestPreferences_ReadUncommitedValues(t *testing.T) {
exampleString := "exampleString"
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile)
p.SetAdminURL(exampleString)
resp, err := p.GetAdminURL()
if err != nil {
t.Fatalf("failed to read admin url: %s", err)
}
if resp != exampleString {
t.Errorf("unexpected admin url: %s", resp)
}
p.SetManagementURL(exampleString)
resp, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read management url: %s", err)
}
if resp != exampleString {
t.Errorf("unexpected management url: %s", resp)
}
p.SetPreSharedKey(exampleString)
resp, err = p.GetPreSharedKey()
if err != nil {
t.Fatalf("failed to read preshared key: %s", err)
}
if resp != exampleString {
t.Errorf("unexpected preshared key: %s", resp)
}
}
func TestPreferences_Commit(t *testing.T) {
exampleURL := "https://myurl.com:443"
examplePresharedKey := "topsecret"
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile)
p.SetAdminURL(exampleURL)
p.SetManagementURL(exampleURL)
p.SetPreSharedKey(examplePresharedKey)
err := p.Commit()
if err != nil {
t.Fatalf("failed to save changes: %s", err)
}
p = NewPreferences(cfgFile)
resp, err := p.GetAdminURL()
if err != nil {
t.Fatalf("failed to read admin url: %s", err)
}
if resp != exampleURL {
t.Errorf("unexpected admin url: %s", resp)
}
resp, err = p.GetManagementURL()
if err != nil {
t.Fatalf("failed to read management url: %s", err)
}
if resp != exampleURL {
t.Errorf("unexpected management url: %s", resp)
}
resp, err = p.GetPreSharedKey()
if err != nil {
t.Fatalf("failed to read preshared key: %s", err)
}
if resp != examplePresharedKey {
t.Errorf("unexpected preshared key: %s", resp)
}
}

View File

@ -12,6 +12,12 @@ import (
// DeviceNameCtxKey context key for device name // DeviceNameCtxKey context key for device name
const DeviceNameCtxKey = "deviceName" const DeviceNameCtxKey = "deviceName"
// OsVersionCtxKey context key for operating system version
const OsVersionCtxKey = "OsVersion"
// OsNameCtxKey context key for operating system name
const OsNameCtxKey = "OsName"
// Info is an object that contains machine information // Info is an object that contains machine information
// Most of the code is taken from https://github.com/matishsiao/goInfo // Most of the code is taken from https://github.com/matishsiao/goInfo
type Info struct { type Info struct {

View File

@ -1,3 +1,6 @@
//go:build !ios
// +build !ios
package system package system
import ( import (

44
client/system/info_ios.go Normal file
View File

@ -0,0 +1,44 @@
//go:build ios
// +build ios
package system
import (
"context"
"runtime"
"github.com/netbirdio/netbird/version"
)
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
// Convert fixed-size byte arrays to Go strings
sysName := extractOsName(ctx, "sysName")
swVersion := extractOsVersion(ctx, "swVersion")
gio := &Info{Kernel: sysName, OSVersion: swVersion, Core: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname = extractDeviceName(ctx, "hostname")
gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
return gio
}
// extractOsVersion extracts operating system version from context or returns the default
func extractOsVersion(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(OsVersionCtxKey).(string)
if !ok {
return defaultName
}
return v
}
// extractOsName extracts operating system name from context or returns the default
func extractOsName(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(OsNameCtxKey).(string)
if !ok {
return defaultName
}
return v
}

View File

@ -28,14 +28,20 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
return wgIFace, nil return wgIFace, nil
} }
// CreateOnMobile creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
return w.tun.Create(mIFaceArgs) return w.tun.Create(mIFaceArgs)
} }
// CreateOniOS creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOniOS(tunFd int32) error {
return fmt.Errorf("this function has not implemented on mobile")
}
// Create this function make sense on mobile only // Create this function make sense on mobile only
func (w *WGIface) Create() error { func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on mobile") return fmt.Errorf("this function has not implemented on mobile")

51
iface/iface_ios.go Normal file
View File

@ -0,0 +1,51 @@
//go:build ios
// +build ios
package iface
import (
"fmt"
"sync"
"github.com/pion/transport/v2"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{
mu: sync.Mutex{},
}
wgAddress, err := parseWGAddress(address)
if err != nil {
return wgIFace, err
}
tun := newTunDevice(ifaceName, wgAddress, mtu, tunAdapter, transportNet)
wgIFace.tun = tun
wgIFace.configurer = newWGConfigurer(tun)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil
}
// CreateOniOS creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOniOS(tunFd int32) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.Create(tunFd)
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error {
return fmt.Errorf("this function has not implemented on mobile")
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on mobile")
}

View File

@ -1,4 +1,5 @@
//go:build !android //go:build !android && !ios
// +build !android,!ios
package iface package iface
@ -27,8 +28,13 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter
return wgIFace, nil return wgIFace, nil
} }
// CreateOnMobile this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
// CreateOniOS this function make sense on mobile only
func (w *WGIface) CreateOniOS(tunFd int32) error {
return fmt.Errorf("this function has not implemented on non mobile") return fmt.Errorf("this function has not implemented on non mobile")
} }

View File

@ -1,3 +1,6 @@
//go:build android || ios
// +build android ios
package iface package iface
import ( import (

View File

@ -1,3 +1,6 @@
//go:build android
// +build android
package iface package iface
import ( import (
@ -56,7 +59,7 @@ func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
//t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
err = t.device.Up() err = t.device.Up()
if err != nil { if err != nil {

View File

@ -1,3 +1,6 @@
//go:build !ios
// +build !ios
package iface package iface
import ( import (

101
iface/tun_ios.go Normal file
View File

@ -0,0 +1,101 @@
//go:build ios
// +build ios
package iface
import (
"os"
"github.com/pion/transport/v2"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind"
)
type tunDevice struct {
address WGAddress
mtu int
tunAdapter TunAdapter
iceBind *bind.ICEBind
fd int
name string
device *device.Device
wrapper *DeviceWrapper
}
func newTunDevice(name string, address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{
name: name,
address: address,
mtu: mtu,
tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet),
}
}
func (t *tunDevice) Create(tunFd int32) error {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(int(tunFd))
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return err
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
log.Errorf("Unable to set tun fd as non blocking: %v", err)
unix.Close(dupTunFd)
return err
}
tun, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
unix.Close(dupTunFd)
return err
}
t.wrapper = newDeviceWrapper(tun)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
err = t.device.Up()
if err != nil {
t.device.Close()
return err
}
log.Debugf("device is ready to use: %s", t.name)
return nil
}
func (t *tunDevice) Device() *device.Device {
return t.device
}
func (t *tunDevice) DeviceName() string {
return t.name
}
func (t *tunDevice) WgAddress() WGAddress {
return t.address
}
func (t *tunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
func (t *tunDevice) Close() (err error) {
if t.device != nil {
t.device.Close()
}
return
}

View File

@ -1,4 +1,4 @@
//go:build (linux || darwin) && !android //go:build (linux || darwin) && !android && !ios
package iface package iface

View File

@ -1,12 +1,17 @@
//go:build ios || android
// +build ios android
package iface package iface
import ( import (
"encoding/hex"
"errors" "errors"
"fmt"
"net" "net"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -42,7 +47,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
} }
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
//parse allowed ips // parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps) _, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil { if err != nil {
return err return err
@ -109,6 +114,52 @@ func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) return c.tunDevice.Device().IpcSet(toWgUserspaceString(config))
} }
func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { func (c *wGConfigurer) removeAllowedIP(peerKey string, ip string) error {
return errFuncNotImplemented ipc, err := c.tunDevice.Device().IpcGet()
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipc, "\n")
output := ""
foundPeer := false
removedAllowedIP := false
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, reset the flag.
if strings.HasPrefix(line, "public_key=") && foundPeer {
foundPeer = false
}
// Identify the peer with the specific public key
if line == fmt.Sprintf("public_key=%s", hexKey) {
foundPeer = true
}
// If we're within the details of the found peer and find the specific allowed IP, skip this line
if foundPeer && line == "allowed_ip="+ip {
removedAllowedIP = true
continue
}
// Append the line to the output string
if strings.HasPrefix(line, "private_key=") || strings.HasPrefix(line, "listen_port=") ||
strings.HasPrefix(line, "public_key=") || strings.HasPrefix(line, "preshared_key=") ||
strings.HasPrefix(line, "endpoint=") || strings.HasPrefix(line, "persistent_keepalive_interval=") ||
strings.HasPrefix(line, "allowed_ip=") {
output += line + "\n"
}
}
if !removedAllowedIP {
return fmt.Errorf("allowedIP not found")
} else {
return c.tunDevice.Device().IpcSet(output)
}
} }

View File

@ -1,4 +1,4 @@
//go:build !android //go:build !android && !ios
package iface package iface
@ -44,7 +44,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
} }
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
//parse allowed ips // parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps) _, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil { if err != nil {
return err return err