Feature/exit node Android (#1916)

Support exit node on Android.
With the protect socket function, we mark every connection that should be used out of VPN.
This commit is contained in:
Zoltan Papp 2024-05-07 12:28:30 +02:00 committed by GitHub
parent f309b120cd
commit c590518e0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 275 additions and 49 deletions

View File

@ -1,3 +1,5 @@
//go:build android
package android package android
import ( import (
@ -14,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@ -59,6 +62,7 @@ type Client struct {
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
deviceName: deviceName, deviceName: deviceName,

View File

@ -0,0 +1,63 @@
package dns
import (
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
)
type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{}
mutex sync.RWMutex
}
func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}),
}
}
func (h *hostsDNSHolder) set(list []string) {
h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{})
for _, dns := range list {
dnsAddr, err := h.normalizeAddress(dns)
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
}
h.mutex.Unlock()
}
func (h *hostsDNSHolder) get() map[string]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList
h.mutex.RUnlock()
return l
}
//nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool {
h.mutex.RLock()
defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream]
return ok
}
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime"
"strings" "strings"
"sync" "sync"
@ -54,9 +55,8 @@ type DefaultServer struct {
currentConfig HostDNSConfig currentConfig HostDNSConfig
// permanent related properties // permanent related properties
permanent bool permanent bool
hostsDnsList []string hostsDNSHolder *hostsDNSHolder
hostsDnsListLock sync.Mutex
// make sense on mobile only // make sense on mobile only
searchDomainNotifier *notifier searchDomainNotifier *notifier
@ -113,8 +113,8 @@ func NewDefaultServerPermanentUpstream(
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.hostsDnsList = hostsDnsList
ds.addHostRootZone() ds.addHostRootZone()
ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort()) ds.currentConfig = dnsConfigToHostDNSConfig(config, ds.service.RuntimeIP(), ds.service.RuntimePort())
ds.searchDomainNotifier = newNotifier(ds.SearchDomains()) ds.searchDomainNotifier = newNotifier(ds.SearchDomains())
@ -147,6 +147,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hostsDNSHolder: newHostsDNSHolder(),
} }
return defaultServer return defaultServer
@ -202,10 +203,8 @@ func (s *DefaultServer) Stop() {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
s.hostsDnsListLock.Lock() s.hostsDNSHolder.set(hostsDnsList)
defer s.hostsDnsListLock.Unlock()
s.hostsDnsList = hostsDnsList
_, ok := s.dnsMuxMap[nbdns.RootZone] _, ok := s.dnsMuxMap[nbdns.RootZone]
if ok { if ok {
log.Debugf("on new host DNS config but skip to apply it") log.Debugf("on new host DNS config but skip to apply it")
@ -374,6 +373,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
s.wgInterface.Address().IP, s.wgInterface.Address().IP,
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err) return nil, fmt.Errorf("unable to create a new upstream resolver, error: %v", err)
@ -452,9 +452,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
_, found := muxUpdateMap[key] _, found := muxUpdateMap[key]
if !found { if !found {
if !isContainRootUpdate && key == nbdns.RootZone { if !isContainRootUpdate && key == nbdns.RootZone {
s.hostsDnsListLock.Lock()
s.addHostRootZone() s.addHostRootZone()
s.hostsDnsListLock.Unlock()
existingHandler.stop() existingHandler.stop()
} else { } else {
existingHandler.stop() existingHandler.stop()
@ -512,6 +510,7 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1 removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
s.service.DeregisterMux(nbdns.RootZone)
} }
for i, item := range s.currentConfig.Domains { for i, item := range s.currentConfig.Domains {
@ -521,10 +520,15 @@ func (s *DefaultServer) upstreamCallbacks(
removeIndex[item.Domain] = i removeIndex[item.Domain] = i
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
}
s.updateNSState(nsGroup, err, false) s.updateNSState(nsGroup, err, false)
} }
@ -545,6 +549,9 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary { if nsGroup.Primary {
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
if runtime.GOOS == "android" {
s.service.RegisterMux(nbdns.RootZone, handler)
}
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
@ -562,25 +569,16 @@ func (s *DefaultServer) addHostRootZone() {
s.wgInterface.Address().IP, s.wgInterface.Address().IP,
s.wgInterface.Address().Network, s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder,
) )
if err != nil { if err != nil {
log.Errorf("unable to create a new upstream resolver, error: %v", err) log.Errorf("unable to create a new upstream resolver, error: %v", err)
return return
} }
handler.upstreamServers = make([]string, len(s.hostsDnsList))
for n, ua := range s.hostsDnsList {
a, err := netip.ParseAddr(ua)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", ua, err)
continue
}
ipString := ua handler.upstreamServers = make([]string, 0)
if !a.Is4() { for k := range s.hostsDNSHolder.get() {
ipString = fmt.Sprintf("[%s]", ua) handler.upstreamServers = append(handler.upstreamServers, k)
}
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ipString)
} }
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}

View File

@ -0,0 +1,84 @@
package dns
import (
"context"
"net"
"syscall"
"time"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net"
)
type upstreamResolver struct {
*upstreamResolverBase
hostsDNSHolder *hostsDNSHolder
}
// newUpstreamResolver in Android we need to distinguish the DNS servers to available through VPN or outside of VPN
// In case if the assigned DNS address is available only in the protected network then the resolver will time out at the
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,
}
upstreamResolverBase.upstreamClient = c
return c, nil
}
// exchange in case of Android if the upstream is a local resolver then we do not need to mark the socket as protected.
// In other case the DNS resolvation goes through the VPN, so we need to force to use the
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
if u.isLocalResolver(upstream) {
return u.exchangeWithoutVPN(ctx, upstream, r)
} else {
return u.exchangeWithinVPN(ctx, upstream, r)
}
}
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
timeout := upstreamTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
dialTimeout := timeout
nbDialer := nbnet.NewDialer()
dialer := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return nbDialer.Control(network, address, c)
},
Timeout: dialTimeout,
}
upstreamExchangeClient := &dns.Client{
Dialer: dialer,
}
return upstreamExchangeClient.Exchange(r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) {
return true
}
return false
}

View File

@ -1,4 +1,4 @@
//go:build !ios //go:build !android && !ios
package dns package dns
@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
) )
type upstreamResolverNonIOS struct { type upstreamResolver struct {
*upstreamResolverBase *upstreamResolverBase
} }
@ -22,16 +22,17 @@ func newUpstreamResolver(
_ net.IP, _ net.IP,
_ *net.IPNet, _ *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) (*upstreamResolverNonIOS, error) { _ *hostsDNSHolder,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
nonIOS := &upstreamResolverNonIOS{ nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
} }
upstreamResolverBase.upstreamClient = nonIOS upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil return nonIOS, nil
} }
func (u *upstreamResolverNonIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{} upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }

View File

@ -28,6 +28,7 @@ func newUpstreamResolver(
ip net.IP, ip net.IP,
net *net.IPNet, net *net.IPNet,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder,
) (*upstreamResolverIOS, error) { ) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)

View File

@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil) resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil)
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@ -155,7 +155,7 @@ func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeL
// InitialRouteRange return the list of initial routes. It used by mobile systems // InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string { func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.initialRouteRanges() return m.notifier.getInitialRouteRanges()
} }
// GetRouteSelector returns the route selector // GetRouteSelector returns the route selector
@ -261,10 +261,7 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
func isPrefixSupported(prefix netip.Prefix) bool { func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() { if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS { return true
case "linux", "windows", "darwin", "ios":
return true
}
} }
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported // If prefix is too small, lets assume it is a possible default prefix which is not yet supported

View File

@ -10,8 +10,8 @@ import (
) )
type notifier struct { type notifier struct {
initialRouteRangers []string initialRouteRanges []string
routeRangers []string routeRanges []string
listener listener.NetworkChangeListener listener listener.NetworkChangeListener
listenerMux sync.Mutex listenerMux sync.Mutex
@ -33,7 +33,7 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
nets = append(nets, r.Network.String()) nets = append(nets, r.Network.String())
} }
sort.Strings(nets) sort.Strings(nets)
n.initialRouteRangers = nets n.initialRouteRanges = nets
} }
func (n *notifier) onNewRoutes(idMap route.HAMap) { func (n *notifier) onNewRoutes(idMap route.HAMap) {
@ -45,11 +45,11 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
} }
sort.Strings(newNets) sort.Strings(newNets)
if !n.hasDiff(n.initialRouteRangers, newNets) { if !n.hasDiff(n.initialRouteRanges, newNets) {
return return
} }
n.routeRangers = newNets n.routeRanges = newNets
n.notify() n.notify()
} }
@ -62,7 +62,7 @@ func (n *notifier) notify() {
} }
go func(l listener.NetworkChangeListener) { go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.routeRangers, ",")) l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
}(n.listener) }(n.listener)
} }
@ -78,6 +78,20 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false return false
} }
func (n *notifier) initialRouteRanges() []string { func (n *notifier) getInitialRouteRanges() []string {
return n.initialRouteRangers return addIPv6RangeIfNeeded(n.initialRouteRanges)
}
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
func addIPv6RangeIfNeeded(inputRanges []string) []string {
ranges := inputRanges
for _, r := range inputRanges {
// we are intentionally adding the ipv6 default range in case of ipv4 default range
// to ensure that all traffic is managed by the tunnel interface on android
if r == "0.0.0.0/0" {
ranges = append(ranges, "::/0")
break
}
}
return ranges
} }

View File

@ -4,4 +4,5 @@ package iface
type TunAdapter interface { type TunAdapter interface {
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error)
UpdateAddr(address string) error UpdateAddr(address string) error
ProtectSocket(fd int32) bool
} }

View File

@ -0,0 +1,25 @@
package net
import (
"syscall"
log "github.com/sirupsen/logrus"
)
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
f := androidProtectSocket
androidProtectSocketLock.Unlock()
if f == nil {
return
}
ok := f(int32(fd))
if !ok {
log.Errorf("failed to protect socket: %d", fd)
}
})
return err
}
}

View File

@ -1,5 +1,3 @@
//go:build android || ios
package net package net
import ( import (

View File

@ -1,4 +1,4 @@
//go:build !android && !ios //go:build !ios
package net package net
@ -36,7 +36,7 @@ func AddDialerCloseHook(hook DialerCloseHookFunc) {
dialerCloseHooks = append(dialerCloseHooks, hook) dialerCloseHooks = append(dialerCloseHooks, hook)
} }
// RemoveDialerHook removes all dialer hooks. // RemoveDialerHooks removes all dialer hooks.
func RemoveDialerHooks() { func RemoveDialerHooks() {
dialerDialHooksMutex.Lock() dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock() defer dialerDialHooksMutex.Unlock()

View File

@ -1,4 +1,4 @@
//go:build !linux || android //go:build !linux
package net package net

View File

@ -0,0 +1,26 @@
package net
import (
"syscall"
log "github.com/sirupsen/logrus"
)
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
f := androidProtectSocket
androidProtectSocketLock.Unlock()
if f == nil {
return
}
ok := f(int32(fd))
if !ok {
log.Errorf("failed to protect listener socket: %d", fd)
}
})
return err
}
}

View File

@ -1,4 +1,4 @@
//go:build android || ios //go:build ios
package net package net

View File

@ -1,4 +1,4 @@
//go:build !android && !ios //go:build !ios
package net package net

View File

@ -1,4 +1,4 @@
//go:build !linux || android //go:build !linux
package net package net

View File

@ -0,0 +1,14 @@
package net
import "sync"
var (
androidProtectSocketLock sync.Mutex
androidProtectSocket func(fd int32) bool
)
func SetAndroidProtectSocketFn(f func(fd int32) bool) {
androidProtectSocketLock.Lock()
androidProtectSocket = f
androidProtectSocketLock.Unlock()
}