diff --git a/client/internal/connect.go b/client/internal/connect.go index 79f97e87f..d8784c0c8 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -43,6 +43,15 @@ func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.S return runClient(ctx, config, statusRecorder, mobileDependency) } +func RunClientiOS(ctx context.Context, config *Config, statusRecorder *peer.Status, fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager) error { + mobileDependency := MobileDependency{ + FileDescriptor: fileDescriptor, + NetworkChangeListener: networkChangeListener, + DnsManager: dnsManager, + } + return runClient(ctx, config, statusRecorder, mobileDependency) +} + func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error { log.Infof("starting NetBird client version %s", version.NetbirdVersion()) diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go index 81b16459b..f49e9fb93 100644 --- a/client/internal/dns/file_linux.go +++ b/client/internal/dns/file_linux.go @@ -35,14 +35,14 @@ func (f *fileConfigurator) supportCustomPort() bool { return false } -func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error { +func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { backupFileExist := false _, err := os.Stat(fileDefaultResolvConfBackupLocation) if err == nil { backupFileExist = true } - if !config.routeAll { + if !config.RouteAll { if backupFileExist { err = f.restore() if err != nil { @@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error { buf := prepareResolvConfContent( searchDomainList, - append([]string{config.serverIP}, nameServers...), + append([]string{config.ServerIP}, nameServers...), others) log.Debugf("creating managed file %s", defaultResolvConfPath) @@ -138,14 +138,14 @@ func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes return buf } -func searchDomains(config hostDNSConfig) []string { +func searchDomains(config HostDNSConfig) []string { listOfDomains := make([]string, 0) - for _, dConf := range config.domains { - if dConf.matchOnly || dConf.disabled { + for _, dConf := range config.Domains { + if dConf.MatchOnly || dConf.Disabled { continue } - listOfDomains = append(listOfDomains, dConf.domain) + listOfDomains = append(listOfDomains, dConf.Domain) } return listOfDomains } @@ -214,7 +214,7 @@ func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, othe return } -// merge search domains lists and cut off the list if it is too long +// merge search Domains lists and cut off the list if it is too long func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string { lineSize := len("search") searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains)) @@ -225,14 +225,14 @@ func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) return searchDomainsList } -// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long +// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long // extend s slice with vs elements // return with the number of characters in the searchDomains line func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int { for _, sd := range vs { tmpCharsNumber := initialLineChars + 1 + len(sd) if tmpCharsNumber > fileMaxLineCharsLimit { - // lets log all skipped domains + // lets log all skipped Domains log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd) continue } @@ -240,7 +240,7 @@ func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string initialLineChars = tmpCharsNumber if len(*s) >= fileMaxNumberOfSearchDomains { - // lets log all skipped domains + // lets log all skipped Domains log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd) continue } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index 4fd164c45..ee50b39d0 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -8,31 +8,31 @@ import ( ) type hostManager interface { - applyDNSConfig(config hostDNSConfig) error + applyDNSConfig(config HostDNSConfig) error restoreHostDNS() error supportCustomPort() bool } -type hostDNSConfig struct { - domains []domainConfig - routeAll bool - serverIP string - serverPort int +type HostDNSConfig struct { + Domains []DomainConfig `json:"domains"` + RouteAll bool `json:"routeAll"` + ServerIP string `json:"serverIP"` + ServerPort int `json:"serverPort"` } -type domainConfig struct { - disabled bool - domain string - matchOnly bool +type DomainConfig struct { + Disabled bool `json:"disabled"` + Domain string `json:"domain"` + MatchOnly bool `json:"matchOnly"` } type mockHostConfigurator struct { - applyDNSConfigFunc func(config hostDNSConfig) error + applyDNSConfigFunc func(config HostDNSConfig) error restoreHostDNSFunc func() error supportCustomPortFunc func() bool } -func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error { +func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { if m.applyDNSConfigFunc != nil { return m.applyDNSConfigFunc(config) } @@ -55,38 +55,38 @@ func (m *mockHostConfigurator) supportCustomPort() bool { func newNoopHostMocker() hostManager { return &mockHostConfigurator{ - applyDNSConfigFunc: func(config hostDNSConfig) error { return nil }, + applyDNSConfigFunc: func(config HostDNSConfig) error { return nil }, restoreHostDNSFunc: func() error { return nil }, supportCustomPortFunc: func() bool { return true }, } } -func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig { - config := hostDNSConfig{ - routeAll: false, - serverIP: ip, - serverPort: port, +func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostDNSConfig { + config := HostDNSConfig{ + RouteAll: false, + ServerIP: ip, + ServerPort: port, } for _, nsConfig := range dnsConfig.NameServerGroups { if len(nsConfig.NameServers) == 0 { continue } if nsConfig.Primary { - config.routeAll = true + config.RouteAll = true } for _, domain := range nsConfig.Domains { - config.domains = append(config.domains, domainConfig{ - domain: strings.TrimSuffix(domain, "."), - matchOnly: !nsConfig.SearchDomainsEnabled, + config.Domains = append(config.Domains, DomainConfig{ + Domain: strings.TrimSuffix(domain, "."), + MatchOnly: !nsConfig.SearchDomainsEnabled, }) } } for _, customZone := range dnsConfig.CustomZones { - config.domains = append(config.domains, domainConfig{ - domain: strings.TrimSuffix(customZone.Domain, "."), - matchOnly: false, + config.Domains = append(config.Domains, DomainConfig{ + Domain: strings.TrimSuffix(customZone.Domain, "."), + MatchOnly: false, }) } diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index 4ab7b32d8..169cc7c47 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -7,7 +7,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) { return &androidHostManager{}, nil } -func (a androidHostManager) applyDNSConfig(config hostDNSConfig) error { +func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { return nil } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index f02c32c22..0f16b7828 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -1,3 +1,5 @@ +//go:build !ios + package dns import ( @@ -42,11 +44,11 @@ func (s *systemConfigurator) supportCustomPort() bool { return true } -func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { +func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { var err error - if config.routeAll { - err = s.addDNSSetupForAll(config.serverIP, config.serverPort) + if config.RouteAll { + err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort) if err != nil { return err } @@ -56,7 +58,7 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { return err } s.primaryServiceID = "" - log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort) + log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort) } var ( @@ -64,20 +66,20 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { matchDomains []string ) - for _, dConf := range config.domains { - if dConf.disabled { + for _, dConf := range config.Domains { + if dConf.Disabled { continue } - if dConf.matchOnly { - matchDomains = append(matchDomains, dConf.domain) + if dConf.MatchOnly { + matchDomains = append(matchDomains, dConf.Domain) continue } - searchDomains = append(searchDomains, dConf.domain) + searchDomains = append(searchDomains, dConf.Domain) } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) if len(matchDomains) != 0 { - err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort) + err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { log.Infof("removing match domains from the system") err = s.removeKeyFromSystemConfig(matchKey) @@ -88,7 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { - err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort) + err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort) } else { log.Infof("removing search domains from the system") err = s.removeKeyFromSystemConfig(searchKey) diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go new file mode 100644 index 000000000..5058d6ba4 --- /dev/null +++ b/client/internal/dns/host_ios.go @@ -0,0 +1,37 @@ +package dns + +import ( + "encoding/json" + + log "github.com/sirupsen/logrus" +) + +type iosHostManager struct { + dnsManager IosDnsManager + config HostDNSConfig +} + +func newHostManager(dnsManager IosDnsManager) (hostManager, error) { + return &iosHostManager{ + dnsManager: dnsManager, + }, nil +} + +func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { + jsonData, err := json.Marshal(config) + if err != nil { + return err + } + jsonString := string(jsonData) + log.Debugf("Applying DNS settings: %s", jsonString) + a.dnsManager.ApplyDns(jsonString) + return nil +} + +func (a iosHostManager) restoreHostDNS() error { + return nil +} + +func (a iosHostManager) supportCustomPort() bool { + return false +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 3814be00b..3a574c4ee 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -43,10 +43,10 @@ func (s *registryConfigurator) supportCustomPort() bool { return false } -func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { +func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { var err error - if config.routeAll { - err = r.addDNSSetupForAll(config.serverIP) + if config.RouteAll { + err = r.addDNSSetupForAll(config.ServerIP) if err != nil { return err } @@ -56,7 +56,7 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { return err } r.routingAll = false - log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP) + log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } var ( @@ -64,18 +64,18 @@ func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { matchDomains []string ) - for _, dConf := range config.domains { - if dConf.disabled { + for _, dConf := range config.Domains { + if dConf.Disabled { continue } - if !dConf.matchOnly { - searchDomains = append(searchDomains, dConf.domain) + if !dConf.MatchOnly { + searchDomains = append(searchDomains, dConf.Domain) } - matchDomains = append(matchDomains, "."+dConf.domain) + matchDomains = append(matchDomains, "."+dConf.Domain) } if len(matchDomains) != 0 { - err = r.addDNSMatchPolicy(matchDomains, config.serverIP) + err = r.addDNSMatchPolicy(matchDomains, config.ServerIP) } else { err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) } diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go index db69d9ad8..b62cd66a9 100644 --- a/client/internal/dns/local_test.go +++ b/client/internal/dns/local_test.go @@ -1,10 +1,12 @@ package dns import ( - "github.com/miekg/dns" - nbdns "github.com/netbirdio/netbird/dns" "strings" "testing" + + "github.com/miekg/dns" + + nbdns "github.com/netbirdio/netbird/dns" ) func TestLocalResolver_ServeDNS(t *testing.T) { diff --git a/client/internal/dns/mockServer.go b/client/internal/dns/mock_server.go similarity index 98% rename from client/internal/dns/mockServer.go rename to client/internal/dns/mock_server.go index 3534fc0c3..ed4116b9d 100644 --- a/client/internal/dns/mockServer.go +++ b/client/internal/dns/mock_server.go @@ -33,7 +33,7 @@ func (m *MockServer) DnsIP() string { } func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { - //TODO implement me + // TODO implement me panic("implement me") } diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_linux.go index d5c2f60b2..b4a7a2514 100644 --- a/client/internal/dns/network_manager_linux.go +++ b/client/internal/dns/network_manager_linux.go @@ -12,8 +12,9 @@ import ( "github.com/godbus/dbus/v5" "github.com/hashicorp/go-version" "github.com/miekg/dns" - nbversion "github.com/netbirdio/netbird/version" log "github.com/sirupsen/logrus" + + nbversion "github.com/netbirdio/netbird/version" ) const ( @@ -93,7 +94,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool { return false } -func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { +func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { connSettings, configVersion, err := n.getAppliedConnectionSettings() if err != nil { return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err) @@ -101,7 +102,7 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er connSettings.cleanDeprecatedSettings() - dnsIP, err := netip.ParseAddr(config.serverIP) + dnsIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %s", err) } @@ -111,33 +112,33 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) er searchDomains []string matchDomains []string ) - for _, dConf := range config.domains { - if dConf.disabled { + for _, dConf := range config.Domains { + if dConf.Disabled { continue } - if dConf.matchOnly { - matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain)) + if dConf.MatchOnly { + matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain)) continue } - searchDomains = append(searchDomains, dns.Fqdn(dConf.domain)) + searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain)) } newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic priority := networkManagerDbusSearchDomainOnlyPriority switch { - case config.routeAll: + case config.RouteAll: priority = networkManagerDbusPrimaryDNSPriority newDomainList = append(newDomainList, "~.") if !n.routingAll { - log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } case len(matchDomains) > 0: priority = networkManagerDbusWithMatchDomainPriority } if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll { - log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) n.routingAll = false } diff --git a/client/internal/dns/notifier.go b/client/internal/dns/notifier.go index 85c270e58..35cb6ff82 100644 --- a/client/internal/dns/notifier.go +++ b/client/internal/dns/notifier.go @@ -52,6 +52,6 @@ func (n *notifier) notify() { } go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged() + l.OnNetworkChanged("") }(n.listener) } diff --git a/client/internal/dns/resolvconf_linux.go b/client/internal/dns/resolvconf_linux.go index 1ae2de3dd..54bdeae12 100644 --- a/client/internal/dns/resolvconf_linux.go +++ b/client/internal/dns/resolvconf_linux.go @@ -39,9 +39,9 @@ func (r *resolvconf) supportCustomPort() bool { return false } -func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error { +func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { var err error - if !config.routeAll { + if !config.RouteAll { err = r.restoreHostDNS() if err != nil { log.Error(err) @@ -54,7 +54,7 @@ func (r *resolvconf) applyDNSConfig(config hostDNSConfig) error { buf := prepareResolvConfContent( searchDomainList, - append([]string{config.serverIP}, r.originalNameServers...), + append([]string{config.ServerIP}, r.originalNameServers...), r.othersConfigs) err = r.applyConfig(buf) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f984f02ec..439c27a27 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -19,6 +19,11 @@ type ReadyListener interface { OnReady() } +// IosDnsManager is a dns manager interface for iOS +type IosDnsManager interface { + ApplyDns(string) +} + // Server is a dns server interface type Server interface { Initialize() error @@ -43,7 +48,7 @@ type DefaultServer struct { hostManager hostManager updateSerial uint64 previousConfigHash uint64 - currentConfig hostDNSConfig + currentConfig HostDNSConfig // permanent related properties permanent bool @@ -52,6 +57,7 @@ type DefaultServer struct { // make sense on mobile only searchDomainNotifier *notifier + iosDnsManager IosDnsManager } type handlerWithStop interface { @@ -99,6 +105,13 @@ func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, return ds } +// NewDefaultServerIos returns a new dns server. It optimized for ios +func NewDefaultServerIos(ctx context.Context, wgInterface WGIface, iosDnsManager IosDnsManager) *DefaultServer { + ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface)) + ds.iosDnsManager = iosDnsManager + return ds +} + func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ @@ -131,8 +144,8 @@ func (s *DefaultServer) Initialize() (err error) { } } - s.hostManager, err = newHostManager(s.wgInterface) - return + s.hostManager, err = s.initialize() + return err } // DnsIP returns the DNS resolver server IP address @@ -223,20 +236,20 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro func (s *DefaultServer) SearchDomains() []string { var searchDomains []string - for _, dConf := range s.currentConfig.domains { - if dConf.disabled { + for _, dConf := range s.currentConfig.Domains { + if dConf.Disabled { continue } - if dConf.matchOnly { + if dConf.MatchOnly { continue } - searchDomains = append(searchDomains, dConf.domain) + searchDomains = append(searchDomains, dConf.Domain) } return searchDomains } func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { - // is the service should be disabled, we stop the listener or fake resolver + // is the service should be Disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the handlers and records if update.ServiceEnable { _ = s.service.Listen() @@ -262,7 +275,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") - hostUpdate.routeAll = false + hostUpdate.RouteAll = false } if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { @@ -312,7 +325,10 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam continue } - handler := newUpstreamResolver(s.ctx) + handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network) + if err != nil { + return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) + } for _, ns := range nsGroup.NameServers { if ns.NSType != nbdns.UDPNameServerType { log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", @@ -445,14 +461,14 @@ func (s *DefaultServer) upstreamCallbacks( } if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 - s.currentConfig.routeAll = false + s.currentConfig.RouteAll = false } - for i, item := range s.currentConfig.domains { - if _, found := removeIndex[item.domain]; found { - s.currentConfig.domains[i].disabled = true - s.service.DeregisterMux(item.domain) - removeIndex[item.domain] = i + for i, item := range s.currentConfig.Domains { + if _, found := removeIndex[item.Domain]; found { + s.currentConfig.Domains[i].Disabled = true + s.service.DeregisterMux(item.Domain) + removeIndex[item.Domain] = i } } if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { @@ -464,28 +480,32 @@ func (s *DefaultServer) upstreamCallbacks( defer s.mux.Unlock() for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.domains) || s.currentConfig.domains[i].domain != domain { + if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { continue } - s.currentConfig.domains[i].disabled = false + s.currentConfig.Domains[i].Disabled = false s.service.RegisterMux(domain, handler) } l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") + l.Debug("reactivate temporary Disabled nameserver group") if nsGroup.Primary { - s.currentConfig.routeAll = true + s.currentConfig.RouteAll = true } if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + l.WithError(err).Error("reactivate temporary Disabled nameserver group, DNS update apply") } } return } func (s *DefaultServer) addHostRootZone() { - handler := newUpstreamResolver(s.ctx) + handler, err := newUpstreamResolver(s.ctx, s.wgInterface.Name(), s.wgInterface.Address().IP, s.wgInterface.Address().Network) + if err != nil { + log.Errorf("unable to create a new upstream resolver, error: %v", err) + return + } handler.upstreamServers = make([]string, len(s.hostsDnsList)) for n, ua := range s.hostsDnsList { a, err := netip.ParseAddr(ua) diff --git a/client/internal/dns/server_android.go b/client/internal/dns/server_android.go new file mode 100644 index 000000000..5e1494e9e --- /dev/null +++ b/client/internal/dns/server_android.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (manager hostManager, err error) { + return newHostManager(s.wgInterface) +} diff --git a/client/internal/dns/server_darwin.go b/client/internal/dns/server_darwin.go new file mode 100644 index 000000000..feeb69352 --- /dev/null +++ b/client/internal/dns/server_darwin.go @@ -0,0 +1,7 @@ +//go:build !ios + +package dns + +func (s *DefaultServer) initialize() (manager hostManager, err error) { + return newHostManager(s.wgInterface) +} diff --git a/client/internal/dns/server_ios.go b/client/internal/dns/server_ios.go new file mode 100644 index 000000000..d04e7ab44 --- /dev/null +++ b/client/internal/dns/server_ios.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (manager hostManager, err error) { + return newHostManager(s.iosDnsManager) +} diff --git a/client/internal/dns/server_linux.go b/client/internal/dns/server_linux.go new file mode 100644 index 000000000..7d7027839 --- /dev/null +++ b/client/internal/dns/server_linux.go @@ -0,0 +1,7 @@ +//go:build !android + +package dns + +func (s *DefaultServer) initialize() (manager hostManager, err error) { + return newHostManager(s.wgInterface) +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 875a1a46f..67d411df5 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -527,8 +527,8 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { registeredMap: make(registrationMap), }, hostManager: hostManager, - currentConfig: hostDNSConfig{ - domains: []domainConfig{ + currentConfig: HostDNSConfig{ + Domains: []DomainConfig{ {false, "domain0", false}, {false, "domain1", false}, {false, "domain2", false}, @@ -537,13 +537,13 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { } var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config hostDNSConfig) error { + hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error { domains := []string{} - for _, item := range config.domains { - if item.disabled { + for _, item := range config.Domains { + if item.Disabled { continue } - domains = append(domains, item.domain) + domains = append(domains, item.Domain) } domainsUpdate = strings.Join(domains, ",") return nil @@ -559,11 +559,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { deactivate() expected := "domain0,domain2" domains := []string{} - for _, item := range server.currentConfig.domains { - if item.disabled { + for _, item := range server.currentConfig.Domains { + if item.Disabled { continue } - domains = append(domains, item.domain) + domains = append(domains, item.Domain) } got := strings.Join(domains, ",") if expected != got { @@ -573,11 +573,11 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { reactivate() expected = "domain0,domain1,domain2" domains = []string{} - for _, item := range server.currentConfig.domains { - if item.disabled { + for _, item := range server.currentConfig.Domains { + if item.Disabled { continue } - domains = append(domains, item.domain) + domains = append(domains, item.Domain) } got = strings.Join(domains, ",") if expected != got { diff --git a/client/internal/dns/server_windows.go b/client/internal/dns/server_windows.go new file mode 100644 index 000000000..5e1494e9e --- /dev/null +++ b/client/internal/dns/server_windows.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (manager hostManager, err error) { + return newHostManager(s.wgInterface) +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0358b0251..3cd4342ad 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -81,8 +81,8 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { return true } -func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { - parsedIP, err := netip.ParseAddr(config.serverIP) +func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { + parsedIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %s", err) } @@ -93,7 +93,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { } err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) if err != nil { - return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err) + return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.ServerIP, config.ServerPort, err) } var ( @@ -101,24 +101,24 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { matchDomains []string domainsInput []systemdDbusLinkDomainsInput ) - for _, dConf := range config.domains { - if dConf.disabled { + for _, dConf := range config.Domains { + if dConf.Disabled { continue } domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ - Domain: dns.Fqdn(dConf.domain), - MatchOnly: dConf.matchOnly, + Domain: dns.Fqdn(dConf.Domain), + MatchOnly: dConf.MatchOnly, }) - if dConf.matchOnly { - matchDomains = append(matchDomains, dConf.domain) + if dConf.MatchOnly { + matchDomains = append(matchDomains, dConf.Domain) continue } - searchDomains = append(searchDomains, dConf.domain) + searchDomains = append(searchDomains, dConf.Domain) } - if config.routeAll { - log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + if config.RouteAll { + log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) if err != nil { return fmt.Errorf("setting link as default dns router, failed with error: %s", err) @@ -129,7 +129,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { }) s.routingAll = true } else if s.routingAll { - log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index d19ac265e..a716e0f24 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "runtime" "sync" "sync/atomic" "time" @@ -21,10 +22,15 @@ const ( ) type upstreamClient interface { - ExchangeContext(ctx context.Context, m *dns.Msg, a string) (r *dns.Msg, rtt time.Duration, err error) + exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) } -type upstreamResolver struct { +type UpstreamResolver interface { + serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) + upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +type upstreamResolverBase struct { ctx context.Context cancel context.CancelFunc upstreamClient upstreamClient @@ -40,25 +46,25 @@ type upstreamResolver struct { reactivate func() } -func newUpstreamResolver(parentCTX context.Context) *upstreamResolver { +func newUpstreamResolverBase(parentCTX context.Context) *upstreamResolverBase { ctx, cancel := context.WithCancel(parentCTX) - return &upstreamResolver{ + + return &upstreamResolverBase{ ctx: ctx, cancel: cancel, - upstreamClient: &dns.Client{}, upstreamTimeout: upstreamTimeout, reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, } } -func (u *upstreamResolver) stop() { +func (u *upstreamResolverBase) stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() } // ServeDNS handles a DNS request -func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { +func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { defer u.checkUpstreamFails() log.WithField("question", r.Question[0]).Trace("received an upstream question") @@ -70,10 +76,8 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } for _, upstream := range u.upstreamServers { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream) - cancel() + rm, t, err := u.upstreamClient.exchange(upstream, r) if err != nil { if err == context.DeadlineExceeded || isTimeout(err) { @@ -83,7 +87,19 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } u.failsCount.Add(1) log.WithError(err).WithField("upstream", upstream). - Error("got an error while querying the upstream") + Error("got other error while querying the upstream") + return + } + + if rm == nil { + log.WithError(err).WithField("upstream", upstream). + Warn("no response from upstream") + return + } + // those checks need to be independent of each other due to memory address issues + if !rm.Response { + log.WithError(err).WithField("upstream", upstream). + Warn("no response from upstream") return } @@ -106,7 +122,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If fails count is greater that failsTillDeact, upstream resolving // will be disabled for reactivatePeriod, after that time period fails counter // will be reset and upstream will be reactivated. -func (u *upstreamResolver) checkUpstreamFails() { +func (u *upstreamResolverBase) checkUpstreamFails() { u.mutex.Lock() defer u.mutex.Unlock() @@ -118,15 +134,18 @@ func (u *upstreamResolver) checkUpstreamFails() { case <-u.ctx.Done(): return default: - log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) - u.deactivate() - u.disabled = true - go u.waitUntilResponse() + // todo test the deactivation logic, it seems to affect the client + if runtime.GOOS != "ios" { + log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod) + u.deactivate() + u.disabled = true + go u.waitUntilResponse() + } } } // waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response -func (u *upstreamResolver) waitUntilResponse() { +func (u *upstreamResolverBase) waitUntilResponse() { exponentialBackOff := &backoff.ExponentialBackOff{ InitialInterval: 500 * time.Millisecond, RandomizationFactor: 0.5, @@ -148,10 +167,7 @@ func (u *upstreamResolver) waitUntilResponse() { var err error for _, upstream := range u.upstreamServers { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - _, _, err = u.upstreamClient.ExchangeContext(ctx, r, upstream) - - cancel() + _, _, err = u.upstreamClient.exchange(upstream, r) if err == nil { return nil diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go new file mode 100644 index 000000000..7283efa20 --- /dev/null +++ b/client/internal/dns/upstream_ios.go @@ -0,0 +1,93 @@ +//go:build ios + +package dns + +import ( + "context" + "net" + "syscall" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +type upstreamResolverIOS struct { + *upstreamResolverBase + lIP net.IP + lNet *net.IPNet + iIndex int +} + +func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverIOS, error) { + upstreamResolverBase := newUpstreamResolverBase(parentCTX) + + index, err := getInterfaceIndex(interfaceName) + if err != nil { + log.Debugf("unable to get interface index for %s: %s", interfaceName, err) + return nil, err + } + + ios := &upstreamResolverIOS{ + upstreamResolverBase: upstreamResolverBase, + lIP: ip, + lNet: net, + iIndex: index, + } + ios.upstreamClient = ios + + return ios, nil +} + +func (u *upstreamResolverIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + client := &dns.Client{} + upstreamHost, _, err := net.SplitHostPort(upstream) + if err != nil { + log.Errorf("error while parsing upstream host: %s", err) + } + upstreamIP := net.ParseIP(upstreamHost) + if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { + log.Debugf("using private client to query upstream: %s", upstream) + client = u.getClientPrivate() + } + + return client.Exchange(r, upstream) +} + +// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface +// This method is needed for iOS +func (u *upstreamResolverIOS) getClientPrivate() *dns.Client { + dialer := &net.Dialer{ + LocalAddr: &net.UDPAddr{ + IP: u.lIP, + Port: 0, // Let the OS pick a free port + }, + Timeout: upstreamTimeout, + Control: func(network, address string, c syscall.RawConn) error { + var operr error + fn := func(s uintptr) { + operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex) + } + + if err := c.Control(fn); err != nil { + return err + } + + if operr != nil { + log.Errorf("error while setting socket option: %s", operr) + } + + return operr + }, + } + client := &dns.Client{ + Dialer: dialer, + } + return client +} + +func getInterfaceIndex(interfaceName string) (int, error) { + iface, err := net.InterfaceByName(interfaceName) + return iface.Index, err +} diff --git a/client/internal/dns/upstream_nonios.go b/client/internal/dns/upstream_nonios.go new file mode 100644 index 000000000..a146f3f98 --- /dev/null +++ b/client/internal/dns/upstream_nonios.go @@ -0,0 +1,32 @@ +//go:build !ios + +package dns + +import ( + "context" + "net" + "time" + + "github.com/miekg/dns" +) + +type upstreamResolverNonIOS struct { + *upstreamResolverBase +} + +func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net.IP, net *net.IPNet) (*upstreamResolverNonIOS, error) { + upstreamResolverBase := newUpstreamResolverBase(parentCTX) + nonIOS := &upstreamResolverNonIOS{ + upstreamResolverBase: upstreamResolverBase, + } + upstreamResolverBase.upstreamClient = nonIOS + return nonIOS, nil +} + +func (u *upstreamResolverNonIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + upstreamExchangeClient := &dns.Client{} + ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) + rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + cancel() + return rm, t, err +} diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 0a5de0b18..d73e04ce0 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -2,6 +2,7 @@ package dns import ( "context" + "net" "strings" "testing" "time" @@ -49,15 +50,6 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { timeout: upstreamTimeout, responseShouldBeNil: true, }, - //{ - // name: "Should Resolve CNAME Record", - // inputMSG: new(dns.Msg).SetQuestion("one.one.one.one", dns.TypeCNAME), - //}, - //{ - // name: "Should Not Write When Not Found A Record", - // inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA), - // responseShouldBeNil: true, - //}, } // should resolve if first upstream times out // should not write when both fails @@ -66,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver := newUpstreamResolver(ctx) + resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}) resolver.upstreamServers = testCase.InputServers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { @@ -114,12 +106,12 @@ type mockUpstreamResolver struct { } // ExchangeContext mock implementation of ExchangeContext from upstreamResolver -func (c mockUpstreamResolver) ExchangeContext(_ context.Context, _ *dns.Msg, _ string) (r *dns.Msg, rtt time.Duration, err error) { +func (c mockUpstreamResolver) exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) { return c.r, c.rtt, c.err } func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - resolver := &upstreamResolver{ + resolver := &upstreamResolverBase{ ctx: context.TODO(), upstreamClient: &mockUpstreamResolver{ err: nil, @@ -156,7 +148,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { } if !resolver.disabled { - t.Errorf("resolver should be disabled") + t.Errorf("resolver should be Disabled") return } diff --git a/client/internal/engine.go b/client/internal/engine.go index c525601b4..43d37e4b7 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -197,7 +197,8 @@ func (e *Engine) Start() error { var routes []*route.Route - if runtime.GOOS == "android" { + switch runtime.GOOS { + case "android": var dnsConfig *nbdns.Config routes, dnsConfig, err = e.readInitialSettings() if err != nil { @@ -207,25 +208,34 @@ func (e *Engine) Start() error { e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses, *dnsConfig, e.mobileDep.NetworkChangeListener) go e.mobileDep.DnsReadyListener.OnReady() } - } else if e.dnsServer == nil { - // todo fix custom address - e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) - if err != nil { - e.close() - return err + case "ios": + if e.dnsServer == nil { + e.dnsServer = dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager) + } + default: + if e.dnsServer == nil { + e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress) + if err != nil { + e.close() + return err + } } } e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) - if runtime.GOOS == "android" { - err = e.wgInterface.CreateOnMobile(iface.MobileIFaceArguments{ + switch runtime.GOOS { + case "android": + err = e.wgInterface.CreateOnAndroid(iface.MobileIFaceArguments{ Routes: e.routeManager.InitialRouteRange(), Dns: e.dnsServer.DnsIP(), SearchDomains: e.dnsServer.SearchDomains(), }) - } else { + case "ios": + e.mobileDep.NetworkChangeListener.SetInterfaceIP(wgAddr) + err = e.wgInterface.CreateOniOS(e.mobileDep.FileDescriptor) + default: err = e.wgInterface.Create() } if err != nil { @@ -480,7 +490,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } // start SSH server if it wasn't running if isNil(e.sshServer) { - //nil sshServer means it has not yet been started + // nil sshServer means it has not yet been started var err error e.sshServer, err = e.sshServerFunc(e.config.SSHKey, fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)) diff --git a/client/internal/listener/network_change.go b/client/internal/listener/network_change.go index ff9cb11f5..08bf5fd52 100644 --- a/client/internal/listener/network_change.go +++ b/client/internal/listener/network_change.go @@ -3,5 +3,6 @@ package listener // NetworkChangeListener is a callback interface for mobile system type NetworkChangeListener interface { // OnNetworkChanged invoke when network settings has been changed - OnNetworkChanged() + OnNetworkChanged(string) + SetInterfaceIP(string) } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index a2bbf2473..1a2a4c2b2 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -14,4 +14,6 @@ type MobileDependency struct { NetworkChangeListener listener.NetworkChangeListener HostDNSAddresses []string DnsReadyListener dns.ReadyListener + DnsManager dns.IosDnsManager + FileDescriptor int32 } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 442f5fe27..db37c0528 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "runtime" "strings" "sync" "time" @@ -225,6 +226,10 @@ func (conn *Conn) candidateTypes() []ice.CandidateType { if hasICEForceRelayConn() { return []ice.CandidateType{ice.CandidateTypeRelay} } + // TODO: remove this once we have refactored userspace proxy into the bind package + if runtime.GOOS == "ios" { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} + } return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} } @@ -464,7 +469,7 @@ func (conn *Conn) cleanup() error { err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { // pretty common error because by that time Engine can already remove the peer and status won't be available. - //todo rethink status updates + // todo rethink status updates log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err) } diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go index 96ca95aa2..e27d08db5 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier.go @@ -2,6 +2,7 @@ package routemanager import ( "sort" + "strings" "sync" "github.com/netbirdio/netbird/client/internal/listener" @@ -50,9 +51,6 @@ func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) { n.routeRangers = newNets - if !n.hasDiff(n.initialRouteRangers, newNets) { - return - } n.notify() } @@ -64,7 +62,7 @@ func (n *notifier) notify() { } go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged() + l.OnNetworkChanged(strings.Join(n.routeRangers, ",")) }(n.listener) } diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 7eafabd77..1918c7f6f 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -1,3 +1,5 @@ +//go:build android + package routemanager import ( diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go new file mode 100644 index 000000000..aae0f8dc8 --- /dev/null +++ b/client/internal/routemanager/systemops_ios.go @@ -0,0 +1,15 @@ +//go:build ios + +package routemanager + +import ( + "net/netip" +) + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { + return nil +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + return nil +} diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go index b229a580f..11247c7dc 100644 --- a/client/internal/routemanager/systemops_nonandroid.go +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build !android && !ios package routemanager diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go new file mode 100644 index 000000000..7c2525901 --- /dev/null +++ b/client/ios/NetBirdSDK/client.go @@ -0,0 +1,224 @@ +package NetBirdSDK + +import ( + "context" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/formatter" +) + +// ConnectionListener export internal Listener for mobile +type ConnectionListener interface { + peer.Listener +} + +// RouteListener export internal RouteListener for mobile +type NetworkChangeListener interface { + listener.NetworkChangeListener +} + +// DnsManager export internal dns Manager for mobile +type DnsManager interface { + dns.IosDnsManager +} + +// CustomLogger export internal CustomLogger for mobile +type CustomLogger interface { + Debug(message string) + Info(message string) + Error(message string) +} + +func init() { + formatter.SetLogcatFormatter(log.StandardLogger()) +} + +// Client struct manage the life circle of background service +type Client struct { + cfgFile string + recorder *peer.Status + ctxCancel context.CancelFunc + ctxCancelLock *sync.Mutex + deviceName string + osName string + osVersion string + networkChangeListener listener.NetworkChangeListener + onHostDnsFn func([]string) + dnsManager dns.IosDnsManager + loginComplete bool +} + +// NewClient instantiate a new Client +func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { + return &Client{ + cfgFile: cfgFile, + deviceName: deviceName, + osName: osName, + osVersion: osVersion, + recorder: peer.NewRecorder(""), + ctxCancelLock: &sync.Mutex{}, + networkChangeListener: networkChangeListener, + dnsManager: dnsManager, + } +} + +// Run start the internal client. It is a blocker function +func (c *Client) Run(fd int32, interfaceName string) error { + log.Infof("Starting NetBird client") + log.Debugf("Tunnel uses interface: %s", interfaceName) + cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ + ConfigPath: c.cfgFile, + }) + if err != nil { + return err + } + c.recorder.UpdateManagementAddress(cfg.ManagementURL.String()) + + var ctx context.Context + //nolint + ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) + //nolint + ctxWithValues = context.WithValue(ctxWithValues, system.OsNameCtxKey, c.osName) + //nolint + ctxWithValues = context.WithValue(ctxWithValues, system.OsVersionCtxKey, c.osVersion) + c.ctxCancelLock.Lock() + ctx, c.ctxCancel = context.WithCancel(ctxWithValues) + defer c.ctxCancel() + c.ctxCancelLock.Unlock() + + auth := NewAuthWithConfig(ctx, cfg) + err = auth.Login() + if err != nil { + return err + } + + log.Infof("Auth successful") + // todo do not throw error in case of cancelled context + ctx = internal.CtxInitState(ctx) + c.onHostDnsFn = func([]string) {} + cfg.WgIface = interfaceName + return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.networkChangeListener, c.dnsManager) +} + +// Stop the internal client and free the resources +func (c *Client) Stop() { + c.ctxCancelLock.Lock() + defer c.ctxCancelLock.Unlock() + if c.ctxCancel == nil { + return + } + + c.ctxCancel() +} + +// ÏSetTraceLogLevel configure the logger to trace level +func (c *Client) SetTraceLogLevel() { + log.SetLevel(log.TraceLevel) +} + +// getStatusDetails return with the list of the PeerInfos +func (c *Client) GetStatusDetails() *StatusDetails { + + fullStatus := c.recorder.GetFullStatus() + + peerInfos := make([]PeerInfo, len(fullStatus.Peers)) + for n, p := range fullStatus.Peers { + pi := PeerInfo{ + p.IP, + p.FQDN, + p.ConnStatus.String(), + } + peerInfos[n] = pi + } + return &StatusDetails{items: peerInfos, fqdn: fullStatus.LocalPeerState.FQDN, ip: fullStatus.LocalPeerState.IP} +} + +// SetConnectionListener set the network connection listener +func (c *Client) SetConnectionListener(listener ConnectionListener) { + c.recorder.SetConnectionListener(listener) +} + +// RemoveConnectionListener remove connection listener +func (c *Client) RemoveConnectionListener() { + c.recorder.RemoveConnectionListener() +} + +func (c *Client) IsLoginRequired() bool { + var ctx context.Context + //nolint + ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) + //nolint + ctxWithValues = context.WithValue(ctxWithValues, system.OsNameCtxKey, c.osName) + //nolint + ctxWithValues = context.WithValue(ctxWithValues, system.OsVersionCtxKey, c.osVersion) + c.ctxCancelLock.Lock() + defer c.ctxCancelLock.Unlock() + ctx, c.ctxCancel = context.WithCancel(ctxWithValues) + + cfg, _ := internal.UpdateOrCreateConfig(internal.ConfigInput{ + ConfigPath: c.cfgFile, + }) + + needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey) + return needsLogin +} + +func (c *Client) LoginForMobile() string { + var ctx context.Context + //nolint + ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) + //nolint + ctxWithValues = context.WithValue(ctxWithValues, system.OsNameCtxKey, c.osName) + //nolint + ctxWithValues = context.WithValue(ctxWithValues, system.OsVersionCtxKey, c.osVersion) + c.ctxCancelLock.Lock() + defer c.ctxCancelLock.Unlock() + ctx, c.ctxCancel = context.WithCancel(ctxWithValues) + + cfg, _ := internal.UpdateOrCreateConfig(internal.ConfigInput{ + ConfigPath: c.cfgFile, + }) + + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) + if err != nil { + return err.Error() + } + + flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) + if err != nil { + return err.Error() + } + + // This could cause a potential race condition with loading the extension which need to be handled on swift side + go func() { + waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second + waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) + defer cancel() + tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + if err != nil { + return + } + jwtToken := tokenInfo.GetTokenToUse() + _ = internal.Login(ctx, cfg, "", jwtToken) + c.loginComplete = true + }() + + return flowInfo.VerificationURIComplete +} + +func (c *Client) IsLoginComplete() bool { + return c.loginComplete +} + +func (c *Client) ClearLoginComplete() { + c.loginComplete = false +} diff --git a/client/ios/NetBirdSDK/gomobile.go b/client/ios/NetBirdSDK/gomobile.go new file mode 100644 index 000000000..9eadd6a7f --- /dev/null +++ b/client/ios/NetBirdSDK/gomobile.go @@ -0,0 +1,5 @@ +package NetBirdSDK + +import _ "golang.org/x/mobile/bind" + +// to keep our CI/CD that checks go.mod and go.sum files happy, we need to import the package above diff --git a/client/ios/NetBirdSDK/logger.go b/client/ios/NetBirdSDK/logger.go new file mode 100644 index 000000000..f1ad1b9f6 --- /dev/null +++ b/client/ios/NetBirdSDK/logger.go @@ -0,0 +1,10 @@ +package NetBirdSDK + +import ( + "github.com/netbirdio/netbird/util" +) + +// InitializeLog initializes the log file. +func InitializeLog(logLevel string, filePath string) error { + return util.InitLog(logLevel, filePath) +} diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go new file mode 100644 index 000000000..257329e5c --- /dev/null +++ b/client/ios/NetBirdSDK/login.go @@ -0,0 +1,159 @@ +package NetBirdSDK + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/cmd" + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/system" +) + +// SSOListener is async listener for mobile framework +type SSOListener interface { + OnSuccess(bool) + OnError(error) +} + +// ErrListener is async listener for mobile framework +type ErrListener interface { + OnSuccess() + OnError(error) +} + +// URLOpener it is a callback interface. The Open function will be triggered if +// the backend want to show an url for the user +type URLOpener interface { + Open(string) +} + +// Auth can register or login new client +type Auth struct { + ctx context.Context + config *internal.Config + cfgPath string +} + +// NewAuth instantiate Auth struct and validate the management URL +func NewAuth(cfgPath string, mgmURL string) (*Auth, error) { + inputCfg := internal.ConfigInput{ + ManagementURL: mgmURL, + } + + cfg, err := internal.CreateInMemoryConfig(inputCfg) + if err != nil { + return nil, err + } + + return &Auth{ + ctx: context.Background(), + config: cfg, + cfgPath: cfgPath, + }, nil +} + +// NewAuthWithConfig instantiate Auth based on existing config +func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth { + return &Auth{ + ctx: ctx, + config: config, + } +} + +// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info. +// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO +// is not supported and returns false without saving the configuration. For other errors return false. +func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { + supportsSSO := true + err := a.withBackOff(a.ctx, func() (err error) { + _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound { + supportsSSO = false + err = nil + } + + return err + } + + return err + }) + + if !supportsSSO { + return false, nil + } + + if err != nil { + return false, fmt.Errorf("backoff cycle failed: %v", err) + } + + err = internal.WriteOutConfig(a.cfgPath, a.config) + return true, err +} + +// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key. +func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { + //nolint + ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) + + err := a.withBackOff(a.ctx, func() error { + backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") + if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { + // we got an answer from management, exit backoff earlier + return backoff.Permanent(backoffErr) + } + return backoffErr + }) + if err != nil { + return fmt.Errorf("backoff cycle failed: %v", err) + } + + return internal.WriteOutConfig(a.cfgPath, a.config) +} + +func (a *Auth) Login() error { + var needsLogin bool + + // check if we need to generate JWT token + err := a.withBackOff(a.ctx, func() (err error) { + needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey) + return + }) + if err != nil { + return fmt.Errorf("backoff cycle failed: %v", err) + } + + jwtToken := "" + if needsLogin { + return fmt.Errorf("Not authenticated") + } + + err = a.withBackOff(a.ctx, func() error { + err := internal.Login(a.ctx, a.config, "", jwtToken) + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { + return nil + } + return err + }) + if err != nil { + return fmt.Errorf("backoff cycle failed: %v", err) + } + + return nil +} + +func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { + return backoff.RetryNotify( + bf, + backoff.WithContext(cmd.CLIBackOffSettings, ctx), + func(err error, duration time.Duration) { + log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) + }) +} diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go new file mode 100644 index 000000000..e52008d9f --- /dev/null +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -0,0 +1,50 @@ +package NetBirdSDK + +// PeerInfo describe information about the peers. It designed for the UI usage +type PeerInfo struct { + IP string + FQDN string + ConnStatus string // Todo replace to enum +} + +// PeerInfoCollection made for Java layer to get non default types as collection +type PeerInfoCollection interface { + Add(s string) PeerInfoCollection + Get(i int) string + Size() int + GetFQDN() string + GetIP() string +} + +// StatusDetails is the implementation of the PeerInfoCollection +type StatusDetails struct { + items []PeerInfo + fqdn string + ip string +} + +// Add new PeerInfo to the collection +func (array StatusDetails) Add(s PeerInfo) StatusDetails { + array.items = append(array.items, s) + return array +} + +// Get return an element of the collection +func (array StatusDetails) Get(i int) *PeerInfo { + return &array.items[i] +} + +// Size return with the size of the collection +func (array StatusDetails) Size() int { + return len(array.items) +} + +// GetFQDN return with the FQDN of the local peer +func (array StatusDetails) GetFQDN() string { + return array.fqdn +} + +// GetIP return with the IP of the local peer +func (array StatusDetails) GetIP() string { + return array.ip +} diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go new file mode 100644 index 000000000..297d53ff0 --- /dev/null +++ b/client/ios/NetBirdSDK/preferences.go @@ -0,0 +1,78 @@ +package NetBirdSDK + +import ( + "github.com/netbirdio/netbird/client/internal" +) + +// Preferences export a subset of the internal config for gomobile +type Preferences struct { + configInput internal.ConfigInput +} + +// NewPreferences create new Preferences instance +func NewPreferences(configPath string) *Preferences { + ci := internal.ConfigInput{ + ConfigPath: configPath, + } + return &Preferences{ci} +} + +// GetManagementURL read url from config file +func (p *Preferences) GetManagementURL() (string, error) { + if p.configInput.ManagementURL != "" { + return p.configInput.ManagementURL, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return "", err + } + return cfg.ManagementURL.String(), err +} + +// SetManagementURL store the given url and wait for commit +func (p *Preferences) SetManagementURL(url string) { + p.configInput.ManagementURL = url +} + +// GetAdminURL read url from config file +func (p *Preferences) GetAdminURL() (string, error) { + if p.configInput.AdminURL != "" { + return p.configInput.AdminURL, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return "", err + } + return cfg.AdminURL.String(), err +} + +// SetAdminURL store the given url and wait for commit +func (p *Preferences) SetAdminURL(url string) { + p.configInput.AdminURL = url +} + +// GetPreSharedKey read preshared key from config file +func (p *Preferences) GetPreSharedKey() (string, error) { + if p.configInput.PreSharedKey != nil { + return *p.configInput.PreSharedKey, nil + } + + cfg, err := internal.ReadConfig(p.configInput.ConfigPath) + if err != nil { + return "", err + } + return cfg.PreSharedKey, err +} + +// SetPreSharedKey store the given key and wait for commit +func (p *Preferences) SetPreSharedKey(key string) { + p.configInput.PreSharedKey = &key +} + +// Commit write out the changes into config file +func (p *Preferences) Commit() error { + _, err := internal.UpdateOrCreateConfig(p.configInput) + return err +} diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go new file mode 100644 index 000000000..aa6a475ae --- /dev/null +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -0,0 +1,120 @@ +package NetBirdSDK + +import ( + "path/filepath" + "testing" + + "github.com/netbirdio/netbird/client/internal" +) + +func TestPreferences_DefaultValues(t *testing.T) { + cfgFile := filepath.Join(t.TempDir(), "netbird.json") + p := NewPreferences(cfgFile) + defaultVar, err := p.GetAdminURL() + if err != nil { + t.Fatalf("failed to read default value: %s", err) + } + + if defaultVar != internal.DefaultAdminURL { + t.Errorf("invalid default admin url: %s", defaultVar) + } + + defaultVar, err = p.GetManagementURL() + if err != nil { + t.Fatalf("failed to read default management URL: %s", err) + } + + if defaultVar != internal.DefaultManagementURL { + t.Errorf("invalid default management url: %s", defaultVar) + } + + var preSharedKey string + preSharedKey, err = p.GetPreSharedKey() + if err != nil { + t.Fatalf("failed to read default preshared key: %s", err) + } + + if preSharedKey != "" { + t.Errorf("invalid preshared key: %s", preSharedKey) + } +} + +func TestPreferences_ReadUncommitedValues(t *testing.T) { + exampleString := "exampleString" + cfgFile := filepath.Join(t.TempDir(), "netbird.json") + p := NewPreferences(cfgFile) + + p.SetAdminURL(exampleString) + resp, err := p.GetAdminURL() + if err != nil { + t.Fatalf("failed to read admin url: %s", err) + } + + if resp != exampleString { + t.Errorf("unexpected admin url: %s", resp) + } + + p.SetManagementURL(exampleString) + resp, err = p.GetManagementURL() + if err != nil { + t.Fatalf("failed to read management url: %s", err) + } + + if resp != exampleString { + t.Errorf("unexpected management url: %s", resp) + } + + p.SetPreSharedKey(exampleString) + resp, err = p.GetPreSharedKey() + if err != nil { + t.Fatalf("failed to read preshared key: %s", err) + } + + if resp != exampleString { + t.Errorf("unexpected preshared key: %s", resp) + } +} + +func TestPreferences_Commit(t *testing.T) { + exampleURL := "https://myurl.com:443" + examplePresharedKey := "topsecret" + cfgFile := filepath.Join(t.TempDir(), "netbird.json") + p := NewPreferences(cfgFile) + + p.SetAdminURL(exampleURL) + p.SetManagementURL(exampleURL) + p.SetPreSharedKey(examplePresharedKey) + + err := p.Commit() + if err != nil { + t.Fatalf("failed to save changes: %s", err) + } + + p = NewPreferences(cfgFile) + resp, err := p.GetAdminURL() + if err != nil { + t.Fatalf("failed to read admin url: %s", err) + } + + if resp != exampleURL { + t.Errorf("unexpected admin url: %s", resp) + } + + resp, err = p.GetManagementURL() + if err != nil { + t.Fatalf("failed to read management url: %s", err) + } + + if resp != exampleURL { + t.Errorf("unexpected management url: %s", resp) + } + + resp, err = p.GetPreSharedKey() + if err != nil { + t.Fatalf("failed to read preshared key: %s", err) + } + + if resp != examplePresharedKey { + t.Errorf("unexpected preshared key: %s", resp) + } +} diff --git a/client/system/info.go b/client/system/info.go index a495ed1e9..2d5b7192e 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -12,6 +12,12 @@ import ( // DeviceNameCtxKey context key for device name const DeviceNameCtxKey = "deviceName" +// OsVersionCtxKey context key for operating system version +const OsVersionCtxKey = "OsVersion" + +// OsNameCtxKey context key for operating system name +const OsNameCtxKey = "OsName" + // Info is an object that contains machine information // Most of the code is taken from https://github.com/matishsiao/goInfo type Info struct { diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index 621a82691..5ae2b4fc6 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -1,3 +1,6 @@ +//go:build !ios +// +build !ios + package system import ( diff --git a/client/system/info_ios.go b/client/system/info_ios.go new file mode 100644 index 000000000..c0e51ec60 --- /dev/null +++ b/client/system/info_ios.go @@ -0,0 +1,44 @@ +//go:build ios +// +build ios + +package system + +import ( + "context" + "runtime" + + "github.com/netbirdio/netbird/version" +) + +// GetInfo retrieves and parses the system information +func GetInfo(ctx context.Context) *Info { + + // Convert fixed-size byte arrays to Go strings + sysName := extractOsName(ctx, "sysName") + swVersion := extractOsVersion(ctx, "swVersion") + + gio := &Info{Kernel: sysName, OSVersion: swVersion, Core: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} + gio.Hostname = extractDeviceName(ctx, "hostname") + gio.WiretrusteeVersion = version.NetbirdVersion() + gio.UIVersion = extractUserAgent(ctx) + + return gio +} + +// extractOsVersion extracts operating system version from context or returns the default +func extractOsVersion(ctx context.Context, defaultName string) string { + v, ok := ctx.Value(OsVersionCtxKey).(string) + if !ok { + return defaultName + } + return v +} + +// extractOsName extracts operating system name from context or returns the default +func extractOsName(ctx context.Context, defaultName string) string { + v, ok := ctx.Value(OsNameCtxKey).(string) + if !ok { + return defaultName + } + return v +} diff --git a/iface/iface_android.go b/iface/iface_android.go index 208eff7a8..6afa2d580 100644 --- a/iface/iface_android.go +++ b/iface/iface_android.go @@ -28,14 +28,20 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter return wgIFace, nil } -// CreateOnMobile creates a new Wireguard interface, sets a given IP and brings it up. +// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. -func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { +func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error { w.mu.Lock() defer w.mu.Unlock() return w.tun.Create(mIFaceArgs) } +// CreateOniOS creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +func (w *WGIface) CreateOniOS(tunFd int32) error { + return fmt.Errorf("this function has not implemented on mobile") +} + // Create this function make sense on mobile only func (w *WGIface) Create() error { return fmt.Errorf("this function has not implemented on mobile") diff --git a/iface/iface_ios.go b/iface/iface_ios.go new file mode 100644 index 000000000..dd68d7792 --- /dev/null +++ b/iface/iface_ios.go @@ -0,0 +1,51 @@ +//go:build ios +// +build ios + +package iface + +import ( + "fmt" + "sync" + + "github.com/pion/transport/v2" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { + wgIFace := &WGIface{ + mu: sync.Mutex{}, + } + + wgAddress, err := parseWGAddress(address) + if err != nil { + return wgIFace, err + } + + tun := newTunDevice(ifaceName, wgAddress, mtu, tunAdapter, transportNet) + wgIFace.tun = tun + + wgIFace.configurer = newWGConfigurer(tun) + + wgIFace.userspaceBind = !WireGuardModuleIsLoaded() + + return wgIFace, nil +} + +// CreateOniOS creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +func (w *WGIface) CreateOniOS(tunFd int32) error { + w.mu.Lock() + defer w.mu.Unlock() + return w.tun.Create(tunFd) +} + +// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error { + return fmt.Errorf("this function has not implemented on mobile") +} + +// Create this function make sense on mobile only +func (w *WGIface) Create() error { + return fmt.Errorf("this function has not implemented on mobile") +} diff --git a/iface/iface_nonandroid.go b/iface/iface_nonandroid.go index da4ef13fd..80a9ab769 100644 --- a/iface/iface_nonandroid.go +++ b/iface/iface_nonandroid.go @@ -1,4 +1,5 @@ -//go:build !android +//go:build !android && !ios +// +build !android,!ios package iface @@ -27,8 +28,13 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter return wgIFace, nil } -// CreateOnMobile this function make sense on mobile only -func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error { +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid(mIFaceArgs MobileIFaceArguments) error { + return fmt.Errorf("this function has not implemented on non mobile") +} + +// CreateOniOS this function make sense on mobile only +func (w *WGIface) CreateOniOS(tunFd int32) error { return fmt.Errorf("this function has not implemented on non mobile") } diff --git a/iface/ipc_parser_android.go b/iface/ipc_parser_mobile.go similarity index 96% rename from iface/ipc_parser_android.go rename to iface/ipc_parser_mobile.go index e1dd66856..7d4af8139 100644 --- a/iface/ipc_parser_android.go +++ b/iface/ipc_parser_mobile.go @@ -1,3 +1,6 @@ +//go:build android || ios +// +build android ios + package iface import ( diff --git a/iface/tun_android.go b/iface/tun_android.go index 30d86e7e2..e938dc57b 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -1,3 +1,6 @@ +//go:build android +// +build android + package iface import ( @@ -56,7 +59,7 @@ func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error { t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode - //t.device.DisableSomeRoamingForBrokenMobileSemantics() + // t.device.DisableSomeRoamingForBrokenMobileSemantics() err = t.device.Up() if err != nil { diff --git a/iface/tun_darwin.go b/iface/tun_darwin.go index a4ab2b4b1..6e917e374 100644 --- a/iface/tun_darwin.go +++ b/iface/tun_darwin.go @@ -1,3 +1,6 @@ +//go:build !ios +// +build !ios + package iface import ( diff --git a/iface/tun_ios.go b/iface/tun_ios.go new file mode 100644 index 000000000..7a9ce5622 --- /dev/null +++ b/iface/tun_ios.go @@ -0,0 +1,101 @@ +//go:build ios +// +build ios + +package iface + +import ( + "os" + + "github.com/pion/transport/v2" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + + "github.com/netbirdio/netbird/iface/bind" +) + +type tunDevice struct { + address WGAddress + mtu int + tunAdapter TunAdapter + iceBind *bind.ICEBind + + fd int + name string + device *device.Device + wrapper *DeviceWrapper +} + +func newTunDevice(name string, address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { + return &tunDevice{ + name: name, + address: address, + mtu: mtu, + tunAdapter: tunAdapter, + iceBind: bind.NewICEBind(transportNet), + } +} + +func (t *tunDevice) Create(tunFd int32) error { + log.Infof("create tun interface") + + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + log.Errorf("Unable to dup tun fd: %v", err) + return err + } + + err = unix.SetNonblock(dupTunFd, true) + if err != nil { + log.Errorf("Unable to set tun fd as non blocking: %v", err) + unix.Close(dupTunFd) + return err + } + tun, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0) + if err != nil { + log.Errorf("Unable to create new tun device from fd: %v", err) + unix.Close(dupTunFd) + return err + } + + t.wrapper = newDeviceWrapper(tun) + log.Debug("Attaching to interface") + t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) + // without this property mobile devices can discover remote endpoints if the configured one was wrong. + // this helps with support for the older NetBird clients that had a hardcoded direct mode + // t.device.DisableSomeRoamingForBrokenMobileSemantics() + + err = t.device.Up() + if err != nil { + t.device.Close() + return err + } + log.Debugf("device is ready to use: %s", t.name) + return nil +} + +func (t *tunDevice) Device() *device.Device { + return t.device +} + +func (t *tunDevice) DeviceName() string { + return t.name +} + +func (t *tunDevice) WgAddress() WGAddress { + return t.address +} + +func (t *tunDevice) UpdateAddr(addr WGAddress) error { + // todo implement + return nil +} + +func (t *tunDevice) Close() (err error) { + if t.device != nil { + t.device.Close() + } + + return +} diff --git a/iface/tun_unix.go b/iface/tun_unix.go index f923362a4..627814fc7 100644 --- a/iface/tun_unix.go +++ b/iface/tun_unix.go @@ -1,4 +1,4 @@ -//go:build (linux || darwin) && !android +//go:build (linux || darwin) && !android && !ios package iface diff --git a/iface/wg_configurer_android.go b/iface/wg_configurer_mobile.go similarity index 60% rename from iface/wg_configurer_android.go rename to iface/wg_configurer_mobile.go index 9328467a6..7f6e5595d 100644 --- a/iface/wg_configurer_android.go +++ b/iface/wg_configurer_mobile.go @@ -1,12 +1,17 @@ +//go:build ios || android +// +build ios android + package iface import ( + "encoding/hex" "errors" + "fmt" "net" + "strings" "time" log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -42,7 +47,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error { } func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - //parse allowed ips + // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { return err @@ -109,6 +114,52 @@ func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error { return c.tunDevice.Device().IpcSet(toWgUserspaceString(config)) } -func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { - return errFuncNotImplemented +func (c *wGConfigurer) removeAllowedIP(peerKey string, ip string) error { + ipc, err := c.tunDevice.Device().IpcGet() + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + hexKey := hex.EncodeToString(peerKeyParsed[:]) + + lines := strings.Split(ipc, "\n") + + output := "" + foundPeer := false + removedAllowedIP := false + for _, line := range lines { + line = strings.TrimSpace(line) + + // If we're within the details of the found peer and encounter another public key, + // this means we're starting another peer's details. So, reset the flag. + if strings.HasPrefix(line, "public_key=") && foundPeer { + foundPeer = false + } + + // Identify the peer with the specific public key + if line == fmt.Sprintf("public_key=%s", hexKey) { + foundPeer = true + } + + // If we're within the details of the found peer and find the specific allowed IP, skip this line + if foundPeer && line == "allowed_ip="+ip { + removedAllowedIP = true + continue + } + + // Append the line to the output string + if strings.HasPrefix(line, "private_key=") || strings.HasPrefix(line, "listen_port=") || + strings.HasPrefix(line, "public_key=") || strings.HasPrefix(line, "preshared_key=") || + strings.HasPrefix(line, "endpoint=") || strings.HasPrefix(line, "persistent_keepalive_interval=") || + strings.HasPrefix(line, "allowed_ip=") { + output += line + "\n" + } + } + + if !removedAllowedIP { + return fmt.Errorf("allowedIP not found") + } else { + return c.tunDevice.Device().IpcSet(output) + } } diff --git a/iface/wg_configurer_nonandroid.go b/iface/wg_configurer_nonmobile.go similarity index 98% rename from iface/wg_configurer_nonandroid.go rename to iface/wg_configurer_nonmobile.go index 3d9aff7a9..c09dda9ad 100644 --- a/iface/wg_configurer_nonandroid.go +++ b/iface/wg_configurer_nonmobile.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build !android && !ios package iface @@ -44,7 +44,7 @@ func (c *wGConfigurer) configureInterface(privateKey string, port int) error { } func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - //parse allowed ips + // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { return err