mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 10:50:45 +01:00
ffffffff
This commit is contained in:
parent
9d820f1eae
commit
d802b7b9ba
@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ func (s *ServiceViaMemory) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||||
|
log.Debugf("registering dns handler for pattern: %s", pattern)
|
||||||
s.dnsMux.Handle(pattern, handler)
|
s.dnsMux.Handle(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -137,10 +136,6 @@ type Engine struct {
|
|||||||
TURNs []*stun.URI
|
TURNs []*stun.URI
|
||||||
stunTurn atomic.Value
|
stunTurn atomic.Value
|
||||||
|
|
||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
|
||||||
clientRoutes route.HAMap
|
|
||||||
clientRoutesMu sync.RWMutex
|
|
||||||
|
|
||||||
clientCtx context.Context
|
clientCtx context.Context
|
||||||
clientCancel context.CancelFunc
|
clientCancel context.CancelFunc
|
||||||
|
|
||||||
@ -300,10 +295,6 @@ func (e *Engine) Stop() error {
|
|||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
|
||||||
e.clientRoutes = nil
|
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
if e.cancel != nil {
|
if e.cancel != nil {
|
||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
@ -383,6 +374,7 @@ func (e *Engine) Start() error {
|
|||||||
initialRoutes,
|
initialRoutes,
|
||||||
e.stateManager,
|
e.stateManager,
|
||||||
dnsServer,
|
dnsServer,
|
||||||
|
e.peerConns,
|
||||||
)
|
)
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -812,15 +804,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
protoRoutes = []*mgmProto.Route{}
|
protoRoutes = []*mgmProto.Route{}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
if err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.clientRoutesMu.Lock()
|
|
||||||
e.clientRoutes = clientRoutes
|
|
||||||
e.clientRoutesMu.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
@ -868,8 +855,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
|
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to update dns server, err: %v", err)
|
log.Errorf("failed to update dns server, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1136,7 +1122,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
|
||||||
case sProto.Body_MODE:
|
case sProto.Body_MODE:
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1323,26 +1309,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the current routes from the route map
|
|
||||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
return maps.Clone(e.clientRoutes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
|
||||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
|
||||||
e.clientRoutesMu.RLock()
|
|
||||||
defer e.clientRoutesMu.RUnlock()
|
|
||||||
|
|
||||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
|
||||||
for id, v := range e.clientRoutes {
|
|
||||||
routes[id.NetID()] = v
|
|
||||||
}
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRouteManager returns the route manager
|
// GetRouteManager returns the route manager
|
||||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||||
return e.routeManager
|
return e.routeManager
|
||||||
@ -1506,7 +1472,7 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
|
|
||||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||||
var vpnRoutes []netip.Prefix
|
var vpnRoutes []netip.Prefix
|
||||||
for _, routes := range e.GetClientRoutes() {
|
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||||
if len(routes) > 0 && routes[0] != nil {
|
if len(routes) > 0 && routes[0] != nil {
|
||||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||||
}
|
}
|
||||||
|
@ -251,7 +251,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil)
|
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil)
|
||||||
_, _, err = engine.routeManager.Init()
|
_, _, err = engine.routeManager.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
|
@ -65,6 +65,7 @@ func newClientNetworkWatcher(
|
|||||||
routeRefCounter *refcounter.RouteRefCounter,
|
routeRefCounter *refcounter.RouteRefCounter,
|
||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
dnsServer nbdns.Server,
|
dnsServer nbdns.Server,
|
||||||
|
peerConns map[string]*peer.Conn,
|
||||||
) *clientNetwork {
|
) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
@ -77,7 +78,16 @@ func newClientNetworkWatcher(
|
|||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface, dnsServer),
|
handler: handlerFromRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouteInterval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
dnsServer,
|
||||||
|
peerConns,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@ -388,13 +398,29 @@ func handlerFromRoute(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
wgInterface iface.IWGIface,
|
wgInterface iface.IWGIface,
|
||||||
dnsServer nbdns.Server,
|
dnsServer nbdns.Server,
|
||||||
|
peerConns map[string]*peer.Conn,
|
||||||
) RouteHandler {
|
) RouteHandler {
|
||||||
if rt.IsDynamic() {
|
if rt.IsDynamic() {
|
||||||
if useNewDNSRoute {
|
if useNewDNSRoute {
|
||||||
return dnsinterceptor.New(rt, routeRefCounter, allowedIPsRefCounter, statusRecorder, dnsServer)
|
return dnsinterceptor.New(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
statusRecorder,
|
||||||
|
dnsServer,
|
||||||
|
peerConns,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
return dynamic.NewRoute(
|
||||||
|
rt,
|
||||||
|
routeRefCounter,
|
||||||
|
allowedIPsRefCounter,
|
||||||
|
dnsRouterInteval,
|
||||||
|
statusRecorder,
|
||||||
|
wgInterface,
|
||||||
|
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ type DnsInterceptor struct {
|
|||||||
dnsServer nbdns.Server
|
dnsServer nbdns.Server
|
||||||
currentPeerKey string
|
currentPeerKey string
|
||||||
interceptedIPs map[string]netip.Prefix
|
interceptedIPs map[string]netip.Prefix
|
||||||
|
peerConns map[string]*peer.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
func New(
|
||||||
@ -32,6 +33,7 @@ func New(
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
dnsServer nbdns.Server,
|
dnsServer nbdns.Server,
|
||||||
|
peerConns map[string]*peer.Conn,
|
||||||
) *DnsInterceptor {
|
) *DnsInterceptor {
|
||||||
return &DnsInterceptor{
|
return &DnsInterceptor{
|
||||||
route: rt,
|
route: rt,
|
||||||
@ -40,6 +42,7 @@ func New(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
dnsServer: dnsServer,
|
dnsServer: dnsServer,
|
||||||
interceptedIPs: make(map[string]netip.Prefix),
|
interceptedIPs: make(map[string]netip.Prefix),
|
||||||
|
peerConns: peerConns,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@ -34,9 +35,11 @@ import (
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
|
GetClientRoutes() route.HAMap
|
||||||
|
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
EnableServerRouter(firewall firewall.Manager) error
|
EnableServerRouter(firewall firewall.Manager) error
|
||||||
@ -61,7 +64,10 @@ type DefaultManager struct {
|
|||||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
|
clientRoutes route.HAMap
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
peerConns map[string]*peer.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(
|
func NewManager(
|
||||||
@ -74,6 +80,7 @@ func NewManager(
|
|||||||
initialRoutes []*route.Route,
|
initialRoutes []*route.Route,
|
||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
dnsServer dns.Server,
|
dnsServer dns.Server,
|
||||||
|
peerConns map[string]*peer.Conn,
|
||||||
) *DefaultManager {
|
) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(ctx)
|
mCTX, cancel := context.WithCancel(ctx)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
@ -92,6 +99,7 @@ func NewManager(
|
|||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
dnsServer: dnsServer,
|
dnsServer: dnsServer,
|
||||||
|
peerConns: peerConns,
|
||||||
}
|
}
|
||||||
|
|
||||||
dm.routeRefCounter = refcounter.New(
|
dm.routeRefCounter = refcounter.New(
|
||||||
@ -120,7 +128,7 @@ func NewManager(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if runtime.GOOS == "android" {
|
if runtime.GOOS == "android" {
|
||||||
cr := dm.clientRoutes(initialRoutes)
|
cr := dm.initialClientRoutes(initialRoutes)
|
||||||
dm.notifier.SetInitialClientRoutes(cr)
|
dm.notifier.SetInitialClientRoutes(cr)
|
||||||
}
|
}
|
||||||
return dm
|
return dm
|
||||||
@ -211,15 +219,21 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.ctx = nil
|
m.ctx = nil
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
m.clientRoutes = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not updating routes as context is closed")
|
log.Infof("not updating routes as context is closed")
|
||||||
return nil, nil, m.ctx.Err()
|
return nil
|
||||||
default:
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
@ -232,12 +246,13 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
}
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||||
@ -255,9 +270,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return m.routeSelector
|
return m.routeSelector
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the client routes
|
// GetClientRoutes returns most recent list of clientRoutes received from the Management Service
|
||||||
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||||
return m.clientNetworks
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
return maps.Clone(m.clientRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
|
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
|
||||||
|
for id, v := range m.clientRoutes {
|
||||||
|
routes[id.NetID()] = v
|
||||||
|
}
|
||||||
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||||
@ -286,6 +316,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
m.routeRefCounter,
|
m.routeRefCounter,
|
||||||
m.allowedIPsRefCounter,
|
m.allowedIPsRefCounter,
|
||||||
m.dnsServer,
|
m.dnsServer,
|
||||||
|
m.peerConns,
|
||||||
)
|
)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
@ -315,16 +346,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
for id, routes := range networks {
|
for id, routes := range networks {
|
||||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||||
if !found {
|
if !found {
|
||||||
clientNetworkWatcher = newClientNetworkWatcher(
|
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, nil)
|
||||||
m.ctx,
|
|
||||||
m.dnsRouteInterval,
|
|
||||||
m.wgInterface,
|
|
||||||
m.statusRecorder,
|
|
||||||
routes[0],
|
|
||||||
m.routeRefCounter,
|
|
||||||
m.allowedIPsRefCounter,
|
|
||||||
m.dnsServer,
|
|
||||||
)
|
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||||
}
|
}
|
||||||
@ -367,7 +389,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
|||||||
return newServerRoutesMap, newClientRoutesIDMap
|
return newServerRoutesMap, newClientRoutesIDMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||||
_, crMap := m.classifyRoutes(initialRoutes)
|
_, crMap := m.classifyRoutes(initialRoutes)
|
||||||
rs := make([]*route.Route, 0, len(crMap))
|
rs := make([]*route.Route, 0, len(crMap))
|
||||||
for _, routes := range crMap {
|
for _, routes := range crMap {
|
||||||
|
@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil)
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
_, _, err = routeManager.Init()
|
||||||
|
|
||||||
@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes with init routes")
|
require.NoError(t, err, "should update routes with init routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
|
@ -2,7 +2,6 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
@ -15,9 +14,11 @@ import (
|
|||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
TriggerSelectionFunc func(haMap route.HAMap)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
|
GetClientRoutesFunc func() route.HAMap
|
||||||
|
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
if m.UpdateRoutesFunc != nil {
|
if m.UpdateRoutesFunc != nil {
|
||||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||||
@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||||
|
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||||
|
if m.GetClientRoutesFunc != nil {
|
||||||
|
return m.GetClientRoutesFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||||
|
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
|
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||||
|
return m.GetClientRoutesWithNetIDFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Start mock implementation of Start from Manager interface
|
// Start mock implementation of Start from Manager interface
|
||||||
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L
|
|||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
routesMap := engine.GetClientRoutesWithNetID()
|
routesMap := engine.GetRouteManager().GetClientRoutesWithNetID()
|
||||||
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
routeSelector := engine.GetRouteManager().GetRouteSelector()
|
||||||
|
|
||||||
var routes []*selectRoute
|
var routes []*selectRoute
|
||||||
@ -116,11 +116,12 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
|
|||||||
routeSelector.SelectAllRoutes()
|
routeSelector.SelectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
routes := toNetIDs(req.GetRouteIDs())
|
routes := toNetIDs(req.GetRouteIDs())
|
||||||
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
netIdRoutes := maps.Keys(engine.GetRouteManager().GetClientRoutesWithNetID())
|
||||||
|
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil {
|
||||||
return nil, fmt.Errorf("select routes: %w", err)
|
return nil, fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(engine.GetRouteManager().GetClientRoutes())
|
||||||
|
|
||||||
return &proto.SelectRoutesResponse{}, nil
|
return &proto.SelectRoutesResponse{}, nil
|
||||||
}
|
}
|
||||||
@ -145,11 +146,12 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
|
|||||||
routeSelector.DeselectAllRoutes()
|
routeSelector.DeselectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
routes := toNetIDs(req.GetRouteIDs())
|
routes := toNetIDs(req.GetRouteIDs())
|
||||||
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
|
netIdRoutes := maps.Keys(engine.GetRouteManager().GetClientRoutesWithNetID())
|
||||||
|
if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil {
|
||||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
routeManager.TriggerSelection(engine.GetClientRoutes())
|
routeManager.TriggerSelection(engine.GetRouteManager().GetClientRoutes())
|
||||||
|
|
||||||
return &proto.SelectRoutesResponse{}, nil
|
return &proto.SelectRoutesResponse{}, nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user