Refactor Route IDs (#1891)

This commit is contained in:
Viktor Liu 2024-05-06 14:47:49 +02:00 committed by GitHub
parent 6a4935139d
commit 4e7c17756c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 320 additions and 292 deletions

View File

@ -112,7 +112,7 @@ type Engine struct {
TURNs []*stun.URI TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route clientRoutes route.HAMap
cancel context.CancelFunc cancel context.CancelFunc
@ -736,9 +736,9 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
for _, protoRoute := range protoRoutes { for _, protoRoute := range protoRoutes {
_, prefix, _ := route.ParseNetwork(protoRoute.Network) _, prefix, _ := route.ParseNetwork(protoRoute.Network)
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: protoRoute.ID, ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix,
NetID: protoRoute.NetID, NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
@ -1238,18 +1238,15 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
} }
// GetClientRoutes returns the current routes from the route map // GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route { func (e *Engine) GetClientRoutes() route.HAMap {
return e.clientRoutes return e.clientRoutes
} }
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route { func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes)) routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes { for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 { routes[id.NetID()] = v
id = id[:i]
}
routes[id] = v
} }
return routes return routes
} }

View File

@ -578,7 +578,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{} }{}
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
input.inputSerial = updateSerial input.inputSerial = updateSerial
input.inputRoutes = newRoutes input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr return nil, nil, testCase.inputErr
@ -743,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
return nil, nil, nil return nil, nil, nil
}, },
} }

View File

@ -33,7 +33,7 @@ type clientNetwork struct {
stop context.CancelFunc stop context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
routes map[string]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan routesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{} routePeersNotifiers map[string]chan struct{}
@ -50,7 +50,7 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
stop: cancel, stop: cancel,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
routes: make(map[string]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
@ -59,8 +59,8 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
return client return client
} }
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[string]routerPeerStatus) routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range c.routes { for _, r := range c.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer) peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
if err != nil { if err != nil {
@ -90,12 +90,12 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
// * Latency: Routes with lower latency are prioritized. // * Latency: Routes with lower latency are prioritized.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
chosen := "" chosen := route.ID("")
chosenScore := float64(0) chosenScore := float64(0)
currScore := float64(0) currScore := float64(0)
currID := "" currID := route.ID("")
if c.chosenRoute != nil { if c.chosenRoute != nil {
currID = c.chosenRoute.ID currID = c.chosenRoute.ID
} }
@ -295,7 +295,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
} }
func (c *clientNetwork) handleUpdate(update routesUpdate) { func (c *clientNetwork) handleUpdate(update routesUpdate) {
updateMap := make(map[string]*route.Route) updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes { for _, r := range update.routes {
updateMap[r.ID] = r updateMap[r.ID] = r

View File

@ -12,21 +12,21 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
statuses map[string]routerPeerStatus statuses map[route.ID]routerPeerStatus
expectedRouteID string expectedRouteID route.ID
currentRoute string currentRoute route.ID
existingRoutes map[string]*route.Route existingRoutes map[route.ID]*route.Route
}{ }{
{ {
name: "one route", name: "one route",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -38,14 +38,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "one connected routes with relayed and direct", name: "one connected routes with relayed and direct",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: true, relayed: true,
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -57,14 +57,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "one connected routes with relayed and no direct", name: "one connected routes with relayed and no direct",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: true, relayed: true,
direct: false, direct: false,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -76,14 +76,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "no connected peers", name: "no connected peers",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: false, connected: false,
relayed: false, relayed: false,
direct: false, direct: false,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -95,7 +95,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with different metrics", name: "multiple connected peers with different metrics",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -107,7 +107,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: 9000, Metric: 9000,
@ -124,7 +124,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with one relayed", name: "multiple connected peers with one relayed",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -136,7 +136,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -153,7 +153,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with one direct", name: "multiple connected peers with one direct",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -165,7 +165,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: false, direct: false,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -182,7 +182,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with different latencies", name: "multiple connected peers with different latencies",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
latency: 300 * time.Millisecond, latency: 300 * time.Millisecond,
@ -192,7 +192,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -209,7 +209,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "should ignore routes with latency 0", name: "should ignore routes with latency 0",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
latency: 0 * time.Millisecond, latency: 0 * time.Millisecond,
@ -219,7 +219,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -236,7 +236,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "current route with similar score and similar but slightly worse latency should not change", name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -250,7 +250,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -267,7 +267,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "current route with bad score should be changed to route with better score", name: "current route with bad score should be changed to route with better score",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -281,7 +281,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -298,7 +298,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "current chosen route doesn't exist anymore", name: "current chosen route doesn't exist anymore",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -312,7 +312,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,

View File

@ -29,8 +29,8 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(map[string][]*route.Route) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
@ -43,7 +43,7 @@ type DefaultManager struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[string]*clientNetwork clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter serverRouter serverRouter serverRouter
statusRecorder *peer.Status statusRecorder *peer.Status
@ -57,7 +57,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(), routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
@ -122,7 +122,7 @@ func (m *DefaultManager) Stop() {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not updating routes as context is closed") log.Infof("not updating routes as context is closed")
@ -164,12 +164,12 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
} }
// GetClientRoutes returns the client routes // GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork { func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
return m.clientNetworks return m.clientNetworks
} }
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) { func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -190,7 +190,7 @@ func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
} }
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list // stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) { func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok { if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id) log.Debugf("Stopping client network watcher, %s", id)
@ -200,7 +200,7 @@ func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route)
} }
} }
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
// removing routes that do not exist as per the update from the Management service. // removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)
@ -219,15 +219,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
} }
} }
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
newClientRoutesIDMap := make(map[string][]*route.Route) newClientRoutesIDMap := make(route.HAMap)
newServerRoutesMap := make(map[string]*route.Route) newServerRoutesMap := make(map[route.ID]*route.Route)
ownNetworkIDs := make(map[string]bool) ownNetworkIDs := make(map[route.HAUniqueID]bool)
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute) haID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey { if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true ownNetworkIDs[haID] = true
// only linux is supported for now // only linux is supported for now
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
@ -238,12 +238,12 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute) haID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] { if !ownNetworkIDs[haID] {
if !isPrefixSupported(newRoute.Network) { if !isPrefixSupported(newRoute.Network) {
continue continue
} }
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
} }
} }

View File

@ -14,8 +14,8 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelectionFunc func(map[string][]*route.Route) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func() StopFunc func()
} }
@ -30,14 +30,14 @@ func (m *MockManager) InitialRouteRange() []string {
} }
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes) return m.UpdateRoutesFunc(updateSerial, newRoutes)
} }
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
} }
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) { func (m *MockManager) TriggerSelection(networks route.HAMap) {
if m.TriggerSelectionFunc != nil { if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks) m.TriggerSelectionFunc(networks)
} }

View File

@ -36,7 +36,7 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRangers = nets n.initialRouteRangers = nets
} }
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) { func (n *notifier) onNewRoutes(idMap route.HAMap) {
newNets := make([]string, 0) newNets := make([]string, 0)
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {

View File

@ -3,7 +3,7 @@ package routemanager
import "github.com/netbirdio/netbird/route" import "github.com/netbirdio/netbird/route"
type serverRouter interface { type serverRouter interface {
updateRoutes(map[string]*route.Route) error updateRoutes(map[route.ID]*route.Route) error
removeFromServerNetwork(*route.Route) error removeFromServerNetwork(*route.Route) error
cleanUp() cleanUp()
} }

View File

@ -19,7 +19,7 @@ import (
type defaultServerRouter struct { type defaultServerRouter struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[string]*route.Route routes map[route.ID]*route.Route
firewall firewall.Manager firewall firewall.Manager
wgInterface *iface.WGIface wgInterface *iface.WGIface
statusRecorder *peer.Status statusRecorder *peer.Status
@ -28,15 +28,15 @@ type defaultServerRouter struct {
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{ return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[string]*route.Route), routes: make(map[route.ID]*route.Route),
firewall: firewall, firewall: firewall,
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
}, nil }, nil
} }
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error { func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
serverRoutesToRemove := make([]string, 0) serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes { for routeID := range m.routes {
update, found := routesMap[routeID] update, found := routesMap[routeID]
@ -168,7 +168,7 @@ func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair,
return firewall.RouterPair{}, err return firewall.RouterPair{}, err
} }
return firewall.RouterPair{ return firewall.RouterPair{
ID: route.ID, ID: string(route.ID),
Source: parsed.String(), Source: parsed.String(),
Destination: route.Network.Masked().String(), Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade, Masquerade: route.Masquerade,

View File

@ -12,22 +12,22 @@ import (
) )
type RouteSelector struct { type RouteSelector struct {
selectedRoutes map[string]struct{} selectedRoutes map[route.NetID]struct{}
selectAll bool selectAll bool
} }
func NewRouteSelector() *RouteSelector { func NewRouteSelector() *RouteSelector {
return &RouteSelector{ return &RouteSelector{
selectedRoutes: map[string]struct{}{}, selectedRoutes: map[route.NetID]struct{}{},
// default selects all routes // default selects all routes
selectAll: true, selectAll: true,
} }
} }
// SelectRoutes updates the selected routes based on the provided route IDs. // SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error { func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
if !appendRoute { if !appendRoute {
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
var multiErr *multierror.Error var multiErr *multierror.Error
@ -51,15 +51,15 @@ func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRout
// SelectAllRoutes sets the selector to select all routes. // SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() { func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode. // If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error { func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
if rs.selectAll { if rs.selectAll {
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
for _, route := range allRoutes { for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{} rs.selectedRoutes[route] = struct{}{}
} }
@ -85,11 +85,11 @@ func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) err
// DeselectAllRoutes deselects all routes, effectively disabling route selection. // DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() { func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool { func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
if rs.selectAll { if rs.selectAll {
return true return true
} }
@ -98,18 +98,14 @@ func (rs *RouteSelector) IsSelected(routeID string) bool {
} }
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route { func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
if rs.selectAll { if rs.selectAll {
return maps.Clone(routes) return maps.Clone(routes)
} }
filtered := map[string][]*route.Route{} filtered := route.HAMap{}
for id, rt := range routes { for id, rt := range routes {
netID := id if rs.IsSelected(id.NetID()) {
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt filtered[id] = rt
} }
} }

View File

@ -12,53 +12,53 @@ import (
) )
func TestRouteSelector_SelectRoutes(t *testing.T) { func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
selectRoutes []string selectRoutes []route.NetID
append bool append bool
wantSelected []string wantSelected []route.NetID
wantError bool wantError bool
}{ }{
{ {
name: "Select specific routes, initial all selected", name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"}, selectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{"route1", "route2"}, wantSelected: []route.NetID{"route1", "route2"},
}, },
{ {
name: "Select specific routes, initial all deselected", name: "Select specific routes, initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
selectRoutes: []string{"route1", "route2"}, selectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{"route1", "route2"}, wantSelected: []route.NetID{"route1", "route2"},
}, },
{ {
name: "Select specific routes with initial selection", name: "Select specific routes with initial selection",
initialSelected: []string{"route1"}, initialSelected: []route.NetID{"route1"},
selectRoutes: []string{"route2", "route3"}, selectRoutes: []route.NetID{"route2", "route3"},
wantSelected: []string{"route2", "route3"}, wantSelected: []route.NetID{"route2", "route3"},
}, },
{ {
name: "Select non-existing route", name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"}, selectRoutes: []route.NetID{"route1", "route4"},
wantSelected: []string{"route1"}, wantSelected: []route.NetID{"route1"},
wantError: true, wantError: true,
}, },
{ {
name: "Append route with initial selection", name: "Append route with initial selection",
initialSelected: []string{"route1"}, initialSelected: []route.NetID{"route1"},
selectRoutes: []string{"route2"}, selectRoutes: []route.NetID{"route2"},
append: true, append: true,
wantSelected: []string{"route1", "route2"}, wantSelected: []route.NetID{"route1", "route2"},
}, },
{ {
name: "Append route without initial selection", name: "Append route without initial selection",
selectRoutes: []string{"route2"}, selectRoutes: []route.NetID{"route2"},
append: true, append: true,
wantSelected: []string{"route2"}, wantSelected: []route.NetID{"route2"},
}, },
} }
@ -86,32 +86,32 @@ func TestRouteSelector_SelectRoutes(t *testing.T) {
} }
func TestRouteSelector_SelectAllRoutes(t *testing.T) { func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
wantSelected []string wantSelected []route.NetID
}{ }{
{ {
name: "Initial all selected", name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
{ {
name: "Initial all deselected", name: "Initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
{ {
name: "Initial some selected", name: "Initial some selected",
initialSelected: []string{"route1"}, initialSelected: []route.NetID{"route1"},
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
{ {
name: "Initial all selected", name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"}, initialSelected: []route.NetID{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
} }
@ -134,39 +134,39 @@ func TestRouteSelector_SelectAllRoutes(t *testing.T) {
} }
func TestRouteSelector_DeselectRoutes(t *testing.T) { func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
deselectRoutes []string deselectRoutes []route.NetID
wantSelected []string wantSelected []route.NetID
wantError bool wantError bool
}{ }{
{ {
name: "Deselect specific routes, initial all selected", name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"}, deselectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{"route3"}, wantSelected: []route.NetID{"route3"},
}, },
{ {
name: "Deselect specific routes, initial all deselected", name: "Deselect specific routes, initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
deselectRoutes: []string{"route1", "route2"}, deselectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Deselect specific routes with initial selection", name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"}, initialSelected: []route.NetID{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"}, deselectRoutes: []route.NetID{"route1", "route3"},
wantSelected: []string{"route2"}, wantSelected: []route.NetID{"route2"},
}, },
{ {
name: "Deselect non-existing route", name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"}, initialSelected: []route.NetID{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"}, deselectRoutes: []route.NetID{"route1", "route4"},
wantSelected: []string{"route2"}, wantSelected: []route.NetID{"route2"},
wantError: true, wantError: true,
}, },
} }
@ -195,32 +195,32 @@ func TestRouteSelector_DeselectRoutes(t *testing.T) {
} }
func TestRouteSelector_DeselectAll(t *testing.T) { func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
wantSelected []string wantSelected []route.NetID
}{ }{
{ {
name: "Initial all selected", name: "Initial all selected",
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Initial all deselected", name: "Initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Initial some selected", name: "Initial some selected",
initialSelected: []string{"route1", "route2"}, initialSelected: []route.NetID{"route1", "route2"},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Initial all selected", name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"}, initialSelected: []route.NetID{"route1", "route2", "route3"},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
} }
@ -245,7 +245,7 @@ func TestRouteSelector_DeselectAll(t *testing.T) {
func TestRouteSelector_IsSelected(t *testing.T) { func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector() rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
require.NoError(t, err) require.NoError(t, err)
assert.True(t, rs.IsSelected("route1")) assert.True(t, rs.IsSelected("route1"))
@ -257,10 +257,10 @@ func TestRouteSelector_IsSelected(t *testing.T) {
func TestRouteSelector_FilterSelected(t *testing.T) { func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector() rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
require.NoError(t, err) require.NoError(t, err)
routes := map[string][]*route.Route{ routes := route.HAMap{
"route1-10.0.0.0/8": {}, "route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {}, "route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {}, "route3-172.16.0.0/12": {},
@ -268,7 +268,7 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
filtered := rs.FilterSelected(routes) filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{ assert.Equal(t, route.HAMap{
"route1-10.0.0.0/8": {}, "route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {}, "route2-192.168.0.0/16": {},
}, filtered) }, filtered)

View File

@ -9,10 +9,11 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
) )
type selectRoute struct { type selectRoute struct {
NetID string NetID route.NetID
Network netip.Prefix Network netip.Prefix
Selected bool Selected bool
} }
@ -60,7 +61,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
var pbRoutes []*proto.Route var pbRoutes []*proto.Route
for _, route := range routes { for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{ pbRoutes = append(pbRoutes, &proto.Route{
ID: route.NetID, ID: string(route.NetID),
Network: route.Network.String(), Network: route.Network.String(),
Selected: route.Selected, Selected: route.Selected,
}) })
@ -81,7 +82,8 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
if req.GetAll() { if req.GetAll() {
routeSelector.SelectAllRoutes() routeSelector.SelectAllRoutes()
} else { } else {
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err) return nil, fmt.Errorf("select routes: %w", err)
} }
} }
@ -100,7 +102,8 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
if req.GetAll() { if req.GetAll() {
routeSelector.DeselectAllRoutes() routeSelector.DeselectAllRoutes()
} else { } else {
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.DeselectRoutes(routes, maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err) return nil, fmt.Errorf("deselect routes: %w", err)
} }
} }
@ -108,3 +111,11 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
return &proto.SelectRoutesResponse{}, nil return &proto.SelectRoutesResponse{}, nil
} }
func toNetIDs(routes []string) []route.NetID {
var netIDs []route.NetID
for _, rt := range routes {
netIDs = append(netIDs, route.NetID(rt))
}
return netIDs
}

View File

@ -100,10 +100,10 @@ type AccountManager interface {
SavePolicy(accountID, userID string, policy *Policy) error SavePolicy(accountID, userID string, policy *Policy) error
DeletePolicy(accountID, policyID, userID string) error DeletePolicy(accountID, policyID, userID string) error
ListPolicies(accountID, userID string) ([]*Policy, error) ListPolicies(accountID, userID string) ([]*Policy, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error SaveRoute(accountID, userID string, route *route.Route) error
DeleteRoute(accountID, routeID, userID string) error DeleteRoute(accountID string, routeID route.ID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
@ -229,7 +229,7 @@ type Account struct {
Groups map[string]*nbgroup.Group `gorm:"-"` Groups map[string]*nbgroup.Group `gorm:"-"`
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*route.Route `gorm:"-"` Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
@ -266,7 +266,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(lookupMap) peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) { for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{} peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
} }
groupListMap := a.getPeerGroups(peerID) groupListMap := a.getPeerGroups(peerID)
@ -284,7 +284,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route var filteredRoutes []*route.Route
for _, r := range routes { for _, r := range routes {
_, found := peerMemberships[route.GetHAUniqueID(r)] _, found := peerMemberships[string(route.GetHAUniqueID(r))]
if !found { if !found {
filteredRoutes = append(filteredRoutes, r) filteredRoutes = append(filteredRoutes, r)
} }
@ -323,7 +323,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
return enabledRoutes, disabledRoutes return enabledRoutes, disabledRoutes
} }
seenRoute := make(map[string]struct{}) seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) { takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok { if _, ok := seenRoute[r.ID]; ok {
@ -354,7 +354,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
newPeerRoute := r.Copy() newPeerRoute := r.Copy()
newPeerRoute.Peer = id newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil newPeerRoute.PeerGroups = nil
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id) takeRoute(newPeerRoute, id)
break break
} }
@ -693,7 +693,7 @@ func (a *Account) Copy() *Account {
policies = append(policies, policy.Copy()) policies = append(policies, policy.Copy())
} }
routes := map[string]*route.Route{} routes := map[route.ID]*route.Route{}
for id, r := range a.Routes { for id, r := range a.Routes {
routes[id] = r.Copy() routes[id] = r.Copy()
} }
@ -1946,7 +1946,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
network := NewNetwork() network := NewNetwork()
peers := make(map[string]*nbpeer.Peer) peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*User) users := make(map[string]*User)
routes := make(map[string]*route.Route) routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID) users[userID] = NewOwnerUser(userID)

View File

@ -1408,7 +1408,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
account := &Account{ account := &Account{
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"route-1": { "route-1": {
ID: "route-1", ID: "route-1",
Network: prefix, Network: prefix,
@ -1437,12 +1437,12 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
routes := account.GetRoutesByPrefix(prefix) routes := account.GetRoutesByPrefix(prefix)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2) routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes { for _, r := range routes {
routeIDs[r.ID] = struct{}{} routeIDs[r.ID] = struct{}{}
} }
assert.Contains(t, routeIDs, "route-1") assert.Contains(t, routeIDs, route.ID("route-1"))
assert.Contains(t, routeIDs, "route-2") assert.Contains(t, routeIDs, route.ID("route-2"))
} }
func TestAccount_GetRoutesToSync(t *testing.T) { func TestAccount_GetRoutesToSync(t *testing.T) {
@ -1459,7 +1459,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
}, },
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"route-1": { "route-1": {
ID: "route-1", ID: "route-1",
Network: prefix, Network: prefix,
@ -1502,12 +1502,12 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2) routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes { for _, r := range routes {
routeIDs[r.ID] = struct{}{} routeIDs[r.ID] = struct{}{}
} }
assert.Contains(t, routeIDs, "route-2") assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, "route-3") assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
@ -1573,7 +1573,7 @@ func TestAccount_Copy(t *testing.T) {
SourcePostureChecks: make([]string, 0), SourcePostureChecks: make([]string, 0),
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
PeerGroups: []string{}, PeerGroups: []string{},

View File

@ -242,7 +242,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
for _, r := range account.Routes { for _, r := range account.Routes {
for _, g := range r.Groups { for _, g := range r.Groups {
if g == groupID { if g == groupID {
return &GroupLinkError{"route", r.NetID} return &GroupLinkError{"route", string(r.NetID)}
} }
} }
} }

View File

@ -107,7 +107,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
newRoute, err := h.accountManager.CreateRoute( newRoute, err := h.accountManager.CreateRoute(
account.Id, newPrefix.String(), peerId, peerGroupIds, account.Id, newPrefix.String(), peerId, peerGroupIds,
req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
) )
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -135,7 +135,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) _, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -185,9 +185,9 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
} }
newRoute := &route.Route{ newRoute := &route.Route{
ID: routeID, ID: route.ID(routeID),
Network: newPrefix, Network: newPrefix,
NetID: req.NetworkId, NetID: route.NetID(req.NetworkId),
NetworkType: prefixType, NetworkType: prefixType,
Masquerade: req.Masquerade, Masquerade: req.Masquerade,
Metric: req.Metric, Metric: req.Metric,
@ -230,7 +230,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteRoute(account.Id, routeID, user.Id) err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -254,7 +254,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id) foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.NotFound, "route not found"), w) util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return return
@ -265,9 +265,9 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
func toRouteResponse(serverRoute *route.Route) *api.Route { func toRouteResponse(serverRoute *route.Route) *api.Route {
route := &api.Route{ route := &api.Route{
Id: serverRoute.ID, Id: string(serverRoute.ID),
Description: serverRoute.Description, Description: serverRoute.Description,
NetworkId: serverRoute.NetID, NetworkId: string(serverRoute.NetID),
Enabled: serverRoute.Enabled, Enabled: serverRoute.Enabled,
Peer: &serverRoute.Peer, Peer: &serverRoute.Peer,
Network: serverRoute.Network.String(), Network: serverRoute.Network.String(),

View File

@ -82,7 +82,7 @@ var testingAccount = &server.Account{
func initRoutesTestData() *RoutesHandler { func initRoutesTestData() *RoutesHandler {
return &RoutesHandler{ return &RoutesHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) { GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }
@ -93,7 +93,7 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}, },
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
if peerID == notFoundPeerID { if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@ -120,7 +120,7 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil return nil
}, },
DeleteRouteFunc: func(_ string, routeID string, _ string) error { DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
if routeID != existingRouteID { if routeID != existingRouteID {
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID) return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
} }

View File

@ -67,7 +67,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
SourcePostureChecks: []string{"1"}, SourcePostureChecks: []string{"1"},
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"1": { "1": {
ID: "1", ID: "1",
PeerGroups: make([]string, 1), PeerGroups: make([]string, 1),
@ -151,7 +151,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
}, },
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"1": { "1": {
ID: "1", ID: "1",
PeerGroups: make([]string, 1), PeerGroups: make([]string, 1),

View File

@ -22,76 +22,76 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error) ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
DeletePeerFunc func(accountID, peerKey, userID string) error DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error) ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error SaveRouteFunc func(accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error) GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error) GroupValidationFunc func(accountId string, groups []string) (bool, error)
} }
@ -399,15 +399,15 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
} }
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
if am.CreateRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID) return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }
// GetRoute mock implementation of GetRoute from server.AccountManager interface // GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil { if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID, userID) return am.GetRouteFunc(accountID, routeID, userID)
} }
@ -415,7 +415,7 @@ func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*rout
} }
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface // SaveRoute mock implementation of SaveRoute from server.AccountManager interface
func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.Route) error { func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
if am.SaveRouteFunc != nil { if am.SaveRouteFunc != nil {
return am.SaveRouteFunc(accountID, userID, route) return am.SaveRouteFunc(accountID, userID, route)
} }
@ -423,7 +423,7 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
} }
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
if am.DeleteRouteFunc != nil { if am.DeleteRouteFunc != nil {
return am.DeleteRouteFunc(accountID, routeID, userID) return am.DeleteRouteFunc(accountID, routeID, userID)
} }

View File

@ -13,7 +13,7 @@ import (
) )
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -40,7 +40,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
} }
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error { func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefix(prefix) routesWithPrefix := account.GetRoutesByPrefix(prefix)
@ -56,7 +56,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[prefixRoute.ID] = true seenPeers[string(prefixRoute.ID)] = true
} }
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
} }
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -131,7 +131,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
} }
var newRoute route.Route var newRoute route.Route
newRoute.ID = xid.New().String() newRoute.ID = route.ID(xid.New().String())
prefixType, newPrefix, err := route.ParseNetwork(network) prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil { if err != nil {
@ -154,7 +154,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" { if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
@ -175,7 +175,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
newRoute.Groups = groups newRoute.Groups = groups
if account.Routes == nil { if account.Routes == nil {
account.Routes = make(map[string]*route.Route) account.Routes = make(map[route.ID]*route.Route)
} }
account.Routes[newRoute.ID] = &newRoute account.Routes[newRoute.ID] = &newRoute
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
am.updateAccountPeers(account) am.updateAccountPeers(account)
am.StoreEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil return &newRoute, nil
} }
@ -209,7 +209,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" { if utf8.RuneCountInString(string(routeToSave.NetID)) > route.MaxNetIDChar || routeToSave.NetID == "" {
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
@ -248,13 +248,13 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
am.updateAccountPeers(account) am.updateAccountPeers(account)
am.StoreEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil return nil
} }
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -274,7 +274,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
return err return err
} }
am.StoreEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
am.updateAccountPeers(account) am.updateAccountPeers(account)
@ -310,8 +310,8 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{ return &proto.Route{
ID: route.ID, ID: string(route.ID),
NetID: route.NetID, NetID: string(route.NetID),
Network: route.Network.String(), Network: route.Network.String(),
NetworkType: int64(route.NetworkType), NetworkType: int64(route.NetworkType),
Peer: route.Peer, Peer: route.Peer,

View File

@ -40,7 +40,7 @@ const (
func TestCreateRoute(t *testing.T) { func TestCreateRoute(t *testing.T) {
type input struct { type input struct {
network string network string
netID string netID route.NetID
peerKey string peerKey string
peerGroupIDs []string peerGroupIDs []string
description string description string
@ -382,8 +382,8 @@ func TestSaveRoute(t *testing.T) {
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34") invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
validMetric := 1000 validMetric := 1000
invalidMetric := 99999 invalidMetric := 99999
validNetID := "12345678901234567890qw" validNetID := route.NetID("12345678901234567890qw")
invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1" invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
validGroupHA1 := routeGroupHA1 validGroupHA1 := routeGroupHA1
validGroupHA2 := routeGroupHA2 validGroupHA2 := routeGroupHA2

View File

@ -451,7 +451,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
} }
account.GroupsG = nil account.GroupsG = nil
account.Routes = make(map[string]*route.Route, len(account.RoutesG)) account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG { for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy() account.Routes[route.ID] = route.Copy()
} }

View File

@ -2,8 +2,6 @@ package server
import ( import (
"fmt" "fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -12,6 +10,9 @@ import (
"testing" "testing"
"time" "time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -75,9 +76,9 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
} }
account.Users[user.Id] = user account.Users[user.Id] = user
route := &route2.Route{ route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n), ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
Description: "base route", Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n), NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"), Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network, NetworkType: route2.IPv4Network,
Metric: 9999, Metric: 9999,

22
route/hauniqueid.go Normal file
View File

@ -0,0 +1,22 @@
package route
import "strings"
type HAUniqueID string
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
func GetHAUniqueID(input *Route) HAUniqueID {
return HAUniqueID(string(input.NetID) + "-" + input.Network.String())
}
func (id HAUniqueID) String() string {
return string(id)
}
// NetID returns the Network ID from the HAUniqueID
func (id HAUniqueID) NetID() NetID {
if i := strings.LastIndex(string(id), "-"); i != -1 {
return NetID(id[:i])
}
return NetID(id)
}

View File

@ -36,6 +36,12 @@ const (
IPv6Network IPv6Network
) )
type ID string
type NetID string
type HAMap map[HAUniqueID][]*Route
// NetworkType route network type // NetworkType route network type
type NetworkType int type NetworkType int
@ -65,11 +71,11 @@ func ToPrefixType(prefix string) NetworkType {
// Route represents a route // Route represents a route
type Route struct { type Route struct {
ID string `gorm:"primaryKey"` ID ID `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Network netip.Prefix `gorm:"serializer:json"` Network netip.Prefix `gorm:"serializer:json"`
NetID string NetID NetID
Description string Description string
Peer string Peer string
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
@ -165,8 +171,3 @@ func compareList(list, other []string) bool {
return true return true
} }
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
func GetHAUniqueID(input *Route) string {
return input.NetID + "-" + input.Network.String()
}