mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-15 11:21:04 +01:00
351 lines
11 KiB
Go
351 lines
11 KiB
Go
package routemanager
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"net/url"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
"github.com/netbirdio/netbird/client/internal/listener"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
|
"github.com/netbirdio/netbird/iface"
|
|
"github.com/netbirdio/netbird/route"
|
|
nbnet "github.com/netbirdio/netbird/util/net"
|
|
"github.com/netbirdio/netbird/version"
|
|
)
|
|
|
|
// Manager is a route manager interface
|
|
type Manager interface {
|
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
|
TriggerSelection(route.HAMap)
|
|
GetRouteSelector() *routeselector.RouteSelector
|
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
|
InitialRouteRange() []string
|
|
EnableServerRouter(firewall firewall.Manager) error
|
|
Stop()
|
|
}
|
|
|
|
// DefaultManager is the default instance of a route manager
|
|
type DefaultManager struct {
|
|
ctx context.Context
|
|
stop context.CancelFunc
|
|
mux sync.Mutex
|
|
clientNetworks map[route.HAUniqueID]*clientNetwork
|
|
routeSelector *routeselector.RouteSelector
|
|
serverRouter serverRouter
|
|
sysOps *systemops.SysOps
|
|
statusRecorder *peer.Status
|
|
wgInterface *iface.WGIface
|
|
pubKey string
|
|
notifier *notifier.Notifier
|
|
routeRefCounter *refcounter.RouteRefCounter
|
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
|
dnsRouteInterval time.Duration
|
|
}
|
|
|
|
func NewManager(
|
|
ctx context.Context,
|
|
pubKey string,
|
|
dnsRouteInterval time.Duration,
|
|
wgInterface *iface.WGIface,
|
|
statusRecorder *peer.Status,
|
|
initialRoutes []*route.Route,
|
|
) *DefaultManager {
|
|
mCTX, cancel := context.WithCancel(ctx)
|
|
notifier := notifier.NewNotifier()
|
|
sysOps := systemops.NewSysOps(wgInterface, notifier)
|
|
|
|
dm := &DefaultManager{
|
|
ctx: mCTX,
|
|
stop: cancel,
|
|
dnsRouteInterval: dnsRouteInterval,
|
|
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
|
routeSelector: routeselector.NewRouteSelector(),
|
|
sysOps: sysOps,
|
|
statusRecorder: statusRecorder,
|
|
wgInterface: wgInterface,
|
|
pubKey: pubKey,
|
|
notifier: notifier,
|
|
}
|
|
|
|
dm.routeRefCounter = refcounter.New(
|
|
func(prefix netip.Prefix, _ any) (any, error) {
|
|
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
|
},
|
|
func(prefix netip.Prefix, _ any) error {
|
|
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
|
},
|
|
)
|
|
|
|
dm.allowedIPsRefCounter = refcounter.New(
|
|
func(prefix netip.Prefix, peerKey string) (string, error) {
|
|
// save peerKey to use it in the remove function
|
|
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
|
|
},
|
|
func(prefix netip.Prefix, peerKey string) error {
|
|
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
|
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
|
|
return err
|
|
}
|
|
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
|
|
}
|
|
return nil
|
|
},
|
|
)
|
|
|
|
if runtime.GOOS == "android" {
|
|
cr := dm.clientRoutes(initialRoutes)
|
|
dm.notifier.SetInitialClientRoutes(cr)
|
|
}
|
|
return dm
|
|
}
|
|
|
|
// Init sets up the routing
|
|
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
|
if nbnet.CustomRoutingDisabled() {
|
|
return nil, nil, nil
|
|
}
|
|
|
|
if err := m.sysOps.CleanupRouting(); err != nil {
|
|
log.Warnf("Failed cleaning up routing: %v", err)
|
|
}
|
|
|
|
mgmtAddress := m.statusRecorder.GetManagementState().URL
|
|
signalAddress := m.statusRecorder.GetSignalState().URL
|
|
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
|
|
|
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
|
}
|
|
log.Info("Routing setup complete")
|
|
return beforePeerHook, afterPeerHook, nil
|
|
}
|
|
|
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
|
var err error
|
|
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Stop stops the manager watchers and clean firewall rules
|
|
func (m *DefaultManager) Stop() {
|
|
m.stop()
|
|
if m.serverRouter != nil {
|
|
m.serverRouter.cleanUp()
|
|
}
|
|
|
|
if m.routeRefCounter != nil {
|
|
if err := m.routeRefCounter.Flush(); err != nil {
|
|
log.Errorf("Error flushing route ref counter: %v", err)
|
|
}
|
|
}
|
|
if m.allowedIPsRefCounter != nil {
|
|
if err := m.allowedIPsRefCounter.Flush(); err != nil {
|
|
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
|
|
}
|
|
}
|
|
|
|
if !nbnet.CustomRoutingDisabled() {
|
|
if err := m.sysOps.CleanupRouting(); err != nil {
|
|
log.Errorf("Error cleaning up routing: %v", err)
|
|
} else {
|
|
log.Info("Routing cleanup complete")
|
|
}
|
|
}
|
|
|
|
m.ctx = nil
|
|
}
|
|
|
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
|
select {
|
|
case <-m.ctx.Done():
|
|
log.Infof("not updating routes as context is closed")
|
|
return nil, nil, m.ctx.Err()
|
|
default:
|
|
m.mux.Lock()
|
|
defer m.mux.Unlock()
|
|
|
|
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
|
|
|
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
|
|
|
if m.serverRouter != nil {
|
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("update routes: %w", err)
|
|
}
|
|
}
|
|
|
|
return newServerRoutesMap, newClientRoutesIDMap, nil
|
|
}
|
|
}
|
|
|
|
// SetRouteChangeListener set RouteListener for route change Notifier
|
|
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
|
|
m.notifier.SetListener(listener)
|
|
}
|
|
|
|
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
|
func (m *DefaultManager) InitialRouteRange() []string {
|
|
return m.notifier.GetInitialRouteRanges()
|
|
}
|
|
|
|
// GetRouteSelector returns the route selector
|
|
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
|
return m.routeSelector
|
|
}
|
|
|
|
// GetClientRoutes returns the client routes
|
|
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
|
return m.clientNetworks
|
|
}
|
|
|
|
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
|
func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|
m.mux.Lock()
|
|
defer m.mux.Unlock()
|
|
|
|
networks = m.routeSelector.FilterSelected(networks)
|
|
|
|
m.notifier.OnNewRoutes(networks)
|
|
|
|
m.stopObsoleteClients(networks)
|
|
|
|
for id, routes := range networks {
|
|
if _, found := m.clientNetworks[id]; found {
|
|
// don't touch existing client network watchers
|
|
continue
|
|
}
|
|
|
|
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
|
m.clientNetworks[id] = clientNetworkWatcher
|
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
|
}
|
|
}
|
|
|
|
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
|
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
|
for id, client := range m.clientNetworks {
|
|
if _, ok := networks[id]; !ok {
|
|
log.Debugf("Stopping client network watcher, %s", id)
|
|
client.cancel()
|
|
delete(m.clientNetworks, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
|
|
// removing routes that do not exist as per the update from the Management service.
|
|
m.stopObsoleteClients(networks)
|
|
|
|
for id, routes := range networks {
|
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
|
if !found {
|
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
|
m.clientNetworks[id] = clientNetworkWatcher
|
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
|
}
|
|
update := routesUpdate{
|
|
updateSerial: updateSerial,
|
|
routes: routes,
|
|
}
|
|
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
|
|
}
|
|
}
|
|
|
|
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
|
|
newClientRoutesIDMap := make(route.HAMap)
|
|
newServerRoutesMap := make(map[route.ID]*route.Route)
|
|
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
|
|
|
for _, newRoute := range newRoutes {
|
|
haID := newRoute.GetHAUniqueID()
|
|
if newRoute.Peer == m.pubKey {
|
|
ownNetworkIDs[haID] = true
|
|
// only linux is supported for now
|
|
if runtime.GOOS != "linux" {
|
|
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
|
continue
|
|
}
|
|
newServerRoutesMap[newRoute.ID] = newRoute
|
|
}
|
|
}
|
|
|
|
for _, newRoute := range newRoutes {
|
|
haID := newRoute.GetHAUniqueID()
|
|
if !ownNetworkIDs[haID] {
|
|
if !isRouteSupported(newRoute) {
|
|
continue
|
|
}
|
|
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
|
}
|
|
}
|
|
|
|
return newServerRoutesMap, newClientRoutesIDMap
|
|
}
|
|
|
|
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
|
_, crMap := m.classifyRoutes(initialRoutes)
|
|
rs := make([]*route.Route, 0, len(crMap))
|
|
for _, routes := range crMap {
|
|
rs = append(rs, routes...)
|
|
}
|
|
return rs
|
|
}
|
|
|
|
func isRouteSupported(route *route.Route) bool {
|
|
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
|
return true
|
|
}
|
|
|
|
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
|
// we skip this prefix management
|
|
if route.Network.Bits() <= vars.MinRangeBits {
|
|
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
|
version.NetbirdVersion(), route.Network)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
|
|
func resolveURLsToIPs(urls []string) []net.IP {
|
|
var ips []net.IP
|
|
for _, rawurl := range urls {
|
|
u, err := url.Parse(rawurl)
|
|
if err != nil {
|
|
log.Errorf("Failed to parse url %s: %v", rawurl, err)
|
|
continue
|
|
}
|
|
ipAddrs, err := net.LookupIP(u.Hostname())
|
|
if err != nil {
|
|
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
|
|
continue
|
|
}
|
|
ips = append(ips, ipAddrs...)
|
|
}
|
|
return ips
|
|
}
|