From c590518e0c94ca64b0e6e72d4e27ede70a43b6f5 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 7 May 2024 12:28:30 +0200 Subject: [PATCH] Feature/exit node Android (#1916) Support exit node on Android. With the protect socket function, we mark every connection that should be used out of VPN. --- client/android/client.go | 4 + client/internal/dns/hosts_dns_holder.go | 63 ++++++++++++++ client/internal/dns/server.go | 42 +++++----- client/internal/dns/upstream_android.go | 84 +++++++++++++++++++ ...upstream_nonios.go => upstream_general.go} | 11 +-- client/internal/dns/upstream_ios.go | 1 + client/internal/dns/upstream_test.go | 2 +- client/internal/routemanager/manager.go | 7 +- client/internal/routemanager/notifier.go | 30 +++++-- iface/tun_adapter.go | 1 + util/net/dialer_android.go | 25 ++++++ util/net/{dialer_mobile.go => dialer_ios.go} | 2 - .../{dialer_generic.go => dialer_nonios.go} | 4 +- util/net/dialer_nonlinux.go | 2 +- util/net/listener_android.go | 26 ++++++ .../{listener_mobile.go => listener_ios.go} | 2 +- ...listener_generic.go => listener_nonios.go} | 2 +- util/net/listener_nonlinux.go | 2 +- util/net/protectsocket_android.go | 14 ++++ 19 files changed, 275 insertions(+), 49 deletions(-) create mode 100644 client/internal/dns/hosts_dns_holder.go create mode 100644 client/internal/dns/upstream_android.go rename client/internal/dns/{upstream_nonios.go => upstream_general.go} (67%) create mode 100644 util/net/dialer_android.go rename util/net/{dialer_mobile.go => dialer_ios.go} (91%) rename util/net/{dialer_generic.go => dialer_nonios.go} (98%) create mode 100644 util/net/listener_android.go rename util/net/{listener_mobile.go => listener_ios.go} (85%) rename util/net/{listener_generic.go => listener_nonios.go} (99%) create mode 100644 util/net/protectsocket_android.go diff --git a/client/android/client.go b/client/android/client.go index 81d3c96e1..297a4d1bc 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -1,3 +1,5 @@ +//go:build android + package android import ( @@ -14,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/util/net" ) // ConnectionListener export internal Listener for mobile @@ -59,6 +62,7 @@ type Client struct { // NewClient instantiate a new Client func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { + net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ cfgFile: cfgFile, deviceName: deviceName, diff --git a/client/internal/dns/hosts_dns_holder.go b/client/internal/dns/hosts_dns_holder.go new file mode 100644 index 000000000..2601af9c8 --- /dev/null +++ b/client/internal/dns/hosts_dns_holder.go @@ -0,0 +1,63 @@ +package dns + +import ( + "fmt" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" +) + +type hostsDNSHolder struct { + unprotectedDNSList map[string]struct{} + mutex sync.RWMutex +} + +func newHostsDNSHolder() *hostsDNSHolder { + return &hostsDNSHolder{ + unprotectedDNSList: make(map[string]struct{}), + } +} + +func (h *hostsDNSHolder) set(list []string) { + h.mutex.Lock() + h.unprotectedDNSList = make(map[string]struct{}) + for _, dns := range list { + dnsAddr, err := h.normalizeAddress(dns) + if err != nil { + continue + } + h.unprotectedDNSList[dnsAddr] = struct{}{} + } + h.mutex.Unlock() +} + +func (h *hostsDNSHolder) get() map[string]struct{} { + h.mutex.RLock() + l := h.unprotectedDNSList + h.mutex.RUnlock() + return l +} + +//nolint:unused +func (h *hostsDNSHolder) isContain(upstream string) bool { + h.mutex.RLock() + defer h.mutex.RUnlock() + + _, ok := h.unprotectedDNSList[upstream] + return ok +} + +func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) { + a, err := netip.ParseAddr(addr) + if err != nil { + log.Errorf("invalid upstream IP address: %s, error: %s", addr, err) + return "", err + } + + if a.Is4() { + return fmt.Sprintf("%s:53", addr), nil + } else { + return fmt.Sprintf("[%s]:53", addr), nil + } +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index b9608b6f2..8f6e8b572 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "runtime" "strings" "sync" @@ -54,9 +55,8 @@ type DefaultServer struct { currentConfig HostDNSConfig // permanent related properties - permanent bool - hostsDnsList []string - hostsDnsListLock sync.Mutex + permanent bool + hostsDNSHolder *hostsDNSHolder // make sense on mobile only searchDomainNotifier *notifier @@ -113,8 +113,8 @@ func NewDefaultServerPermanentUpstream( ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) + ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true - ds.hostsDnsList = hostsDnsList ds.addHostRootZone() ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) @@ -147,6 +147,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi }, wgInterface: wgInterface, statusRecorder: statusRecorder, + hostsDNSHolder: newHostsDNSHolder(), } return defaultServer @@ -202,10 +203,8 @@ func (s *DefaultServer) Stop() { // OnUpdatedHostDNSServer update the DNS servers addresses for root zones // It will be applied if the mgm server do not enforce DNS settings for root zone func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { - s.hostsDnsListLock.Lock() - defer s.hostsDnsListLock.Unlock() + s.hostsDNSHolder.set(hostsDnsList) - s.hostsDnsList = hostsDnsList _, ok := s.dnsMuxMap[nbdns.RootZone] if ok { log.Debugf("on new host DNS config but skip to apply it") @@ -374,6 +373,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam s.wgInterface.Address().IP, s.wgInterface.Address().Network, s.statusRecorder, + s.hostsDNSHolder, ) if err != nil { return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) @@ -452,9 +452,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { _, found := muxUpdateMap[key] if !found { if !isContainRootUpdate && key == nbdns.RootZone { - s.hostsDnsListLock.Lock() s.addHostRootZone() - s.hostsDnsListLock.Unlock() existingHandler.stop() } else { existingHandler.stop() @@ -512,6 +510,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false + s.service.DeregisterMux(nbdns.RootZone) } for i, item := range s.currentConfig.Domains { @@ -521,10 +520,15 @@ func (s *DefaultServer) upstreamCallbacks( removeIndex[item.Domain] = i } } + if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } + if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { + s.addHostRootZone() + } + s.updateNSState(nsGroup, err, false) } @@ -545,6 +549,9 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true + if runtime.GOOS == "android" { + s.service.RegisterMux(nbdns.RootZone, handler) + } } if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") @@ -562,25 +569,16 @@ func (s *DefaultServer) addHostRootZone() { s.wgInterface.Address().IP, s.wgInterface.Address().Network, s.statusRecorder, + s.hostsDNSHolder, ) if err != nil { log.Errorf("unable to create a new upstream resolver, error: %v", err) return } - handler.upstreamServers = make([]string, len(s.hostsDnsList)) - for n, ua := range s.hostsDnsList { - a, err := netip.ParseAddr(ua) - if err != nil { - log.Errorf("invalid upstream IP address: %s, error: %s", ua, err) - continue - } - ipString := ua - if !a.Is4() { - ipString = fmt.Sprintf("[%s]", ua) - } - - handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString) + handler.upstreamServers = make([]string, 0) + for k := range s.hostsDNSHolder.get() { + handler.upstreamServers = append(handler.upstreamServers, k) } handler.deactivate = func(error) {} handler.reactivate = func() {} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go new file mode 100644 index 000000000..36ea05e44 --- /dev/null +++ b/client/internal/dns/upstream_android.go @@ -0,0 +1,84 @@ +package dns + +import ( + "context" + "net" + "syscall" + "time" + + "github.com/miekg/dns" + + "github.com/netbirdio/netbird/client/internal/peer" + nbnet "github.com/netbirdio/netbird/util/net" +) + +type upstreamResolver struct { + *upstreamResolverBase + hostsDNSHolder *hostsDNSHolder +} + +// newUpstreamResolver in Android we need to distinguish the DNS servers to available through VPN or outside of VPN +// In case if the assigned DNS address is available only in the protected network then the resolver will time out at the +// first time, and we need to wait for a while to start to use again the proper DNS resolver. +func newUpstreamResolver( + ctx context.Context, + _ string, + _ net.IP, + _ *net.IPNet, + statusRecorder *peer.Status, + hostsDNSHolder *hostsDNSHolder, +) (*upstreamResolver, error) { + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) + c := &upstreamResolver{ + upstreamResolverBase: upstreamResolverBase, + hostsDNSHolder: hostsDNSHolder, + } + upstreamResolverBase.upstreamClient = c + return c, nil +} + +// exchange in case of Android if the upstream is a local resolver then we do not need to mark the socket as protected. +// In other case the DNS resolvation goes through the VPN, so we need to force to use the +func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + if u.isLocalResolver(upstream) { + return u.exchangeWithoutVPN(ctx, upstream, r) + } else { + return u.exchangeWithinVPN(ctx, upstream, r) + } +} + +func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + upstreamExchangeClient := &dns.Client{} + return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) +} + +// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN +func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + timeout := upstreamTimeout + if deadline, ok := ctx.Deadline(); ok { + timeout = time.Until(deadline) + } + dialTimeout := timeout + + nbDialer := nbnet.NewDialer() + + dialer := &net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + return nbDialer.Control(network, address, c) + }, + Timeout: dialTimeout, + } + + upstreamExchangeClient := &dns.Client{ + Dialer: dialer, + } + + return upstreamExchangeClient.Exchange(r, upstream) +} + +func (u *upstreamResolver) isLocalResolver(upstream string) bool { + if u.hostsDNSHolder.isContain(upstream) { + return true + } + return false +} diff --git a/client/internal/dns/upstream_nonios.go b/client/internal/dns/upstream_general.go similarity index 67% rename from client/internal/dns/upstream_nonios.go rename to client/internal/dns/upstream_general.go index 22bd24ca9..a29350f8c 100644 --- a/client/internal/dns/upstream_nonios.go +++ b/client/internal/dns/upstream_general.go @@ -1,4 +1,4 @@ -//go:build !ios +//go:build !android && !ios package dns @@ -12,7 +12,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" ) -type upstreamResolverNonIOS struct { +type upstreamResolver struct { *upstreamResolverBase } @@ -22,16 +22,17 @@ func newUpstreamResolver( _ net.IP, _ *net.IPNet, statusRecorder *peer.Status, -) (*upstreamResolverNonIOS, error) { + _ *hostsDNSHolder, +) (*upstreamResolver, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) - nonIOS := &upstreamResolverNonIOS{ + nonIOS := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, } upstreamResolverBase.upstreamClient = nonIOS return nonIOS, nil } -func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { +func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { upstreamExchangeClient := &dns.Client{} return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index c9d3bb942..0c01a013e 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -28,6 +28,7 @@ func newUpstreamResolver( ip net.IP, net *net.IPNet, statusRecorder *peer.Status, + _ *hostsDNSHolder, ) (*upstreamResolverIOS, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 77851dd9d..c1251dcc1 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil) + resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil) resolver.upstreamServers = testCase.InputServers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9f0f74213..9ad423ab9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -155,7 +155,7 @@ func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeL // InitialRouteRange return the list of initial routes. It used by mobile systems func (m *DefaultManager) InitialRouteRange() []string { - return m.notifier.initialRouteRanges() + return m.notifier.getInitialRouteRanges() } // GetRouteSelector returns the route selector @@ -261,10 +261,7 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou func isPrefixSupported(prefix netip.Prefix) bool { if !nbnet.CustomRoutingDisabled() { - switch runtime.GOOS { - case "linux", "windows", "darwin", "ios": - return true - } + return true } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go index d0c02612e..20c7c333a 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier.go @@ -10,8 +10,8 @@ import ( ) type notifier struct { - initialRouteRangers []string - routeRangers []string + initialRouteRanges []string + routeRanges []string listener listener.NetworkChangeListener listenerMux sync.Mutex @@ -33,7 +33,7 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { nets = append(nets, r.Network.String()) } sort.Strings(nets) - n.initialRouteRangers = nets + n.initialRouteRanges = nets } func (n *notifier) onNewRoutes(idMap route.HAMap) { @@ -45,11 +45,11 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) { } sort.Strings(newNets) - if !n.hasDiff(n.initialRouteRangers, newNets) { + if !n.hasDiff(n.initialRouteRanges, newNets) { return } - n.routeRangers = newNets + n.routeRanges = newNets n.notify() } @@ -62,7 +62,7 @@ func (n *notifier) notify() { } go func(l listener.NetworkChangeListener) { - l.OnNetworkChanged(strings.Join(n.routeRangers, ",")) + l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ",")) }(n.listener) } @@ -78,6 +78,20 @@ func (n *notifier) hasDiff(a []string, b []string) bool { return false } -func (n *notifier) initialRouteRanges() []string { - return n.initialRouteRangers +func (n *notifier) getInitialRouteRanges() []string { + return addIPv6RangeIfNeeded(n.initialRouteRanges) +} + +// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route. +func addIPv6RangeIfNeeded(inputRanges []string) []string { + ranges := inputRanges + for _, r := range inputRanges { + // we are intentionally adding the ipv6 default range in case of ipv4 default range + // to ensure that all traffic is managed by the tunnel interface on android + if r == "0.0.0.0/0" { + ranges = append(ranges, "::/0") + break + } + } + return ranges } diff --git a/iface/tun_adapter.go b/iface/tun_adapter.go index c10eb3d19..adec93ed1 100644 --- a/iface/tun_adapter.go +++ b/iface/tun_adapter.go @@ -4,4 +4,5 @@ package iface type TunAdapter interface { ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) UpdateAddr(address string) error + ProtectSocket(fd int32) bool } diff --git a/util/net/dialer_android.go b/util/net/dialer_android.go new file mode 100644 index 000000000..4cbded536 --- /dev/null +++ b/util/net/dialer_android.go @@ -0,0 +1,25 @@ +package net + +import ( + "syscall" + + log "github.com/sirupsen/logrus" +) + +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + f := androidProtectSocket + androidProtectSocketLock.Unlock() + if f == nil { + return + } + ok := f(int32(fd)) + if !ok { + log.Errorf("failed to protect socket: %d", fd) + } + }) + return err + } +} diff --git a/util/net/dialer_mobile.go b/util/net/dialer_ios.go similarity index 91% rename from util/net/dialer_mobile.go rename to util/net/dialer_ios.go index b95aaa973..0541979f6 100644 --- a/util/net/dialer_mobile.go +++ b/util/net/dialer_ios.go @@ -1,5 +1,3 @@ -//go:build android || ios - package net import ( diff --git a/util/net/dialer_generic.go b/util/net/dialer_nonios.go similarity index 98% rename from util/net/dialer_generic.go rename to util/net/dialer_nonios.go index 1e217da13..7a5de7587 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_nonios.go @@ -1,4 +1,4 @@ -//go:build !android && !ios +//go:build !ios package net @@ -36,7 +36,7 @@ func AddDialerCloseHook(hook DialerCloseHookFunc) { dialerCloseHooks = append(dialerCloseHooks, hook) } -// RemoveDialerHook removes all dialer hooks. +// RemoveDialerHooks removes all dialer hooks. func RemoveDialerHooks() { dialerDialHooksMutex.Lock() defer dialerDialHooksMutex.Unlock() diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go index 3254e6d06..c838441bd 100644 --- a/util/net/dialer_nonlinux.go +++ b/util/net/dialer_nonlinux.go @@ -1,4 +1,4 @@ -//go:build !linux || android +//go:build !linux package net diff --git a/util/net/listener_android.go b/util/net/listener_android.go new file mode 100644 index 000000000..d4167ad53 --- /dev/null +++ b/util/net/listener_android.go @@ -0,0 +1,26 @@ +package net + +import ( + "syscall" + + log "github.com/sirupsen/logrus" +) + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + f := androidProtectSocket + androidProtectSocketLock.Unlock() + if f == nil { + return + } + ok := f(int32(fd)) + if !ok { + log.Errorf("failed to protect listener socket: %d", fd) + } + }) + return err + } +} diff --git a/util/net/listener_mobile.go b/util/net/listener_ios.go similarity index 85% rename from util/net/listener_mobile.go rename to util/net/listener_ios.go index 0dbbb360b..5c90c2161 100644 --- a/util/net/listener_mobile.go +++ b/util/net/listener_ios.go @@ -1,4 +1,4 @@ -//go:build android || ios +//go:build ios package net diff --git a/util/net/listener_generic.go b/util/net/listener_nonios.go similarity index 99% rename from util/net/listener_generic.go rename to util/net/listener_nonios.go index 7847a29c7..ae4be3494 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_nonios.go @@ -1,4 +1,4 @@ -//go:build !android && !ios +//go:build !ios package net diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go index fb6eadaaa..14a6be49d 100644 --- a/util/net/listener_nonlinux.go +++ b/util/net/listener_nonlinux.go @@ -1,4 +1,4 @@ -//go:build !linux || android +//go:build !linux package net diff --git a/util/net/protectsocket_android.go b/util/net/protectsocket_android.go new file mode 100644 index 000000000..64fb45aa4 --- /dev/null +++ b/util/net/protectsocket_android.go @@ -0,0 +1,14 @@ +package net + +import "sync" + +var ( + androidProtectSocketLock sync.Mutex + androidProtectSocket func(fd int32) bool +) + +func SetAndroidProtectSocketFn(f func(fd int32) bool) { + androidProtectSocketLock.Lock() + androidProtectSocket = f + androidProtectSocketLock.Unlock() +}