Support DNS routes on iOS (#2254)

This commit is contained in:
pascal-fischer 2024-07-15 10:40:57 +02:00 committed by GitHub
parent 58fbc1249c
commit 47752e1573
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 185 additions and 26 deletions

View File

@ -16,6 +16,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
@ -50,7 +51,7 @@ type DefaultManager struct {
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
pubKey string pubKey string
notifier *notifier notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
@ -65,7 +66,8 @@ func NewManager(
initialRoutes []*route.Route, initialRoutes []*route.Route,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface) notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
@ -77,7 +79,7 @@ func NewManager(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: newNotifier(), notifier: notifier,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@ -107,7 +109,7 @@ func NewManager(
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes) cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr) dm.notifier.SetInitialClientRoutes(cr)
} }
return dm return dm
} }
@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes)
if m.serverRouter != nil { if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
} }
} }
// SetRouteChangeListener set RouteListener for route change notifier // SetRouteChangeListener set RouteListener for route change Notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) { func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.setListener(listener) m.notifier.SetListener(listener)
} }
// InitialRouteRange return the list of initial routes. It used by mobile systems // InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string { func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.getInitialRouteRanges() return m.notifier.GetInitialRouteRanges()
} }
// GetRouteSelector returns the route selector // GetRouteSelector returns the route selector
@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
networks = m.routeSelector.FilterSelected(networks) networks = m.routeSelector.FilterSelected(networks)
m.notifier.onNewRoutes(networks) m.notifier.OnNewRoutes(networks)
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)

View File

@ -1,6 +1,7 @@
package routemanager package notifier
import ( import (
"net/netip"
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
@ -10,7 +11,7 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type notifier struct { type Notifier struct {
initialRouteRanges []string initialRouteRanges []string
routeRanges []string routeRanges []string
@ -18,17 +19,17 @@ type notifier struct {
listenerMux sync.Mutex listenerMux sync.Mutex
} }
func newNotifier() *notifier { func NewNotifier() *Notifier {
return &notifier{} return &Notifier{}
} }
func (n *notifier) setListener(listener listener.NetworkChangeListener) { func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock() n.listenerMux.Lock()
defer n.listenerMux.Unlock() defer n.listenerMux.Unlock()
n.listener = listener n.listener = listener
} }
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
nets = append(nets, r.Network.String()) nets = append(nets, r.Network.String())
@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRanges = nets n.initialRouteRanges = nets
} }
func (n *notifier) onNewRoutes(idMap route.HAMap) { func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0) newNets := make([]string, 0)
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {
@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
n.notify() n.notify()
} }
func (n *notifier) notify() { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock() n.listenerMux.Lock()
defer n.listenerMux.Unlock() defer n.listenerMux.Unlock()
if n.listener == nil { if n.listener == nil {
@ -74,7 +101,7 @@ func (n *notifier) notify() {
}(n.listener) }(n.listener)
} }
func (n *notifier) hasDiff(a []string, b []string) bool { func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) { if len(a) != len(b) {
return true return true
} }
@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false return false
} }
func (n *notifier) getInitialRouteRanges() []string { func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges) return addIPv6RangeIfNeeded(n.initialRouteRanges)
} }

View File

@ -3,7 +3,9 @@ package systemops
import ( import (
"net" "net"
"net/netip" "net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct { type SysOps struct {
refCounter *ExclusionCounter refCounter *ExclusionCounter
wgInterface *iface.WGIface wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
prefixes map[netip.Prefix]struct{}
//nolint
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
} }
func NewSysOps(wgInterface *iface.WGIface) *SysOps { func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier,
} }
} }

View File

@ -1,4 +1,4 @@
//go:build ios || android //go:build android
package systemops package systemops

View File

@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"} intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil) r := NewSysOps(nil, nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {

View File

@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface) r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil) _, _, err = r.SetupRouting(nil)
require.NoError(t, err) require.NoError(t, err)
@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface) r := NewSysOps(wgInterface, nil)
// Prepare the environment // Prepare the environment
if testCase.preExistingPrefix.IsValid() { if testCase.preExistingPrefix.IsValid() {
@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close()) assert.NoError(t, wgInterface.Close())
}) })
r := NewSysOps(wgInterface) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil) _, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {

View File

@ -0,0 +1,64 @@
//go:build ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
r.notify()
return nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes[prefix] = struct{}{}
r.notify()
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.prefixes, prefix)
r.notify()
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -47,6 +48,7 @@ type CustomLogger interface {
type selectRoute struct { type selectRoute struct {
NetID string NetID string
Network netip.Prefix Network netip.Prefix
Domains domain.List
Selected bool Selected bool
} }
@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
route := &selectRoute{ route := &selectRoute{
NetID: string(id), NetID: string(id),
Network: rt[0].Network, Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id), Selected: routeSelector.IsSelected(id),
} }
routes = append(routes, route) routes = append(routes, route)
@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return iPrefix < jPrefix return iPrefix < jPrefix
}) })
resolvedDomains := c.recorder.GetResolvedDomainsStates()
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo var routeSelection []RoutesSelectionInfo
for _, r := range routes { for _, r := range routes {
domainList := make([]DomainInfo, 0)
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
}
domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{ routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID, ID: r.NetID,
Network: r.Network.String(), Network: r.Network.String(),
Domains: &domainDetails,
Selected: r.Selected, Selected: r.Selected,
}) })
} }
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection} routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails, nil return &routeSelectionDetails
} }
func (c *Client) SelectRoute(id string) error { func (c *Client) SelectRoute(id string) error {

View File

@ -16,9 +16,25 @@ type RoutesSelectionDetails struct {
type RoutesSelectionInfo struct { type RoutesSelectionInfo struct {
ID string ID string
Network string Network string
Domains *DomainDetails
Selected bool Selected bool
} }
type DomainCollection interface {
Add(s DomainInfo) DomainCollection
Get(i int) *DomainInfo
Size() int
}
type DomainDetails struct {
items []DomainInfo
}
type DomainInfo struct {
Domain string
ResolvedIPs string
}
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails { func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
array.items = append(array.items, s) array.items = append(array.items, s)
@ -34,3 +50,16 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
func (array RoutesSelectionDetails) Size() int { func (array RoutesSelectionDetails) Size() int {
return len(array.items) return len(array.items)
} }
func (array DomainDetails) Add(s DomainInfo) DomainCollection {
array.items = append(array.items, s)
return array
}
func (array DomainDetails) Get(i int) *DomainInfo {
return &array.items[i]
}
func (array DomainDetails) Size() int {
return len(array.items)
}