1
0
mirror of https://github.com/netbirdio/netbird.git synced 2025-04-27 21:09:09 +02:00

Refactor Route IDs ()

This commit is contained in:
Viktor Liu 2024-05-06 14:47:49 +02:00 committed by GitHub
parent 6a4935139d
commit 4e7c17756c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 320 additions and 292 deletions

View File

@ -112,7 +112,7 @@ type Engine struct {
TURNs []*stun.URI TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route clientRoutes route.HAMap
cancel context.CancelFunc cancel context.CancelFunc
@ -736,9 +736,9 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
for _, protoRoute := range protoRoutes { for _, protoRoute := range protoRoutes {
_, prefix, _ := route.ParseNetwork(protoRoute.Network) _, prefix, _ := route.ParseNetwork(protoRoute.Network)
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: protoRoute.ID, ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix,
NetID: protoRoute.NetID, NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
@ -1238,18 +1238,15 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
} }
// GetClientRoutes returns the current routes from the route map // GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route { func (e *Engine) GetClientRoutes() route.HAMap {
return e.clientRoutes return e.clientRoutes
} }
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route { func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes)) routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes { for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 { routes[id.NetID()] = v
id = id[:i]
}
routes[id] = v
} }
return routes return routes
} }

View File

@ -578,7 +578,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{} }{}
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
input.inputSerial = updateSerial input.inputSerial = updateSerial
input.inputRoutes = newRoutes input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr return nil, nil, testCase.inputErr
@ -743,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
return nil, nil, nil return nil, nil, nil
}, },
} }

View File

@ -33,7 +33,7 @@ type clientNetwork struct {
stop context.CancelFunc stop context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
routes map[string]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan routesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{} routePeersNotifiers map[string]chan struct{}
@ -50,7 +50,7 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
stop: cancel, stop: cancel,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
routes: make(map[string]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
@ -59,8 +59,8 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
return client return client
} }
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses := make(map[string]routerPeerStatus) routePeerStatuses := make(map[route.ID]routerPeerStatus)
for _, r := range c.routes { for _, r := range c.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer) peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
if err != nil { if err != nil {
@ -90,12 +90,12 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
// * Latency: Routes with lower latency are prioritized. // * Latency: Routes with lower latency are prioritized.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
chosen := "" chosen := route.ID("")
chosenScore := float64(0) chosenScore := float64(0)
currScore := float64(0) currScore := float64(0)
currID := "" currID := route.ID("")
if c.chosenRoute != nil { if c.chosenRoute != nil {
currID = c.chosenRoute.ID currID = c.chosenRoute.ID
} }
@ -295,7 +295,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
} }
func (c *clientNetwork) handleUpdate(update routesUpdate) { func (c *clientNetwork) handleUpdate(update routesUpdate) {
updateMap := make(map[string]*route.Route) updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes { for _, r := range update.routes {
updateMap[r.ID] = r updateMap[r.ID] = r

View File

@ -12,21 +12,21 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
statuses map[string]routerPeerStatus statuses map[route.ID]routerPeerStatus
expectedRouteID string expectedRouteID route.ID
currentRoute string currentRoute route.ID
existingRoutes map[string]*route.Route existingRoutes map[route.ID]*route.Route
}{ }{
{ {
name: "one route", name: "one route",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -38,14 +38,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "one connected routes with relayed and direct", name: "one connected routes with relayed and direct",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: true, relayed: true,
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -57,14 +57,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "one connected routes with relayed and no direct", name: "one connected routes with relayed and no direct",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: true, relayed: true,
direct: false, direct: false,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -76,14 +76,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "no connected peers", name: "no connected peers",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: false, connected: false,
relayed: false, relayed: false,
direct: false, direct: false,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -95,7 +95,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with different metrics", name: "multiple connected peers with different metrics",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -107,7 +107,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: 9000, Metric: 9000,
@ -124,7 +124,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with one relayed", name: "multiple connected peers with one relayed",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -136,7 +136,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: true, direct: true,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -153,7 +153,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with one direct", name: "multiple connected peers with one direct",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -165,7 +165,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
direct: false, direct: false,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -182,7 +182,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "multiple connected peers with different latencies", name: "multiple connected peers with different latencies",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
latency: 300 * time.Millisecond, latency: 300 * time.Millisecond,
@ -192,7 +192,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -209,7 +209,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "should ignore routes with latency 0", name: "should ignore routes with latency 0",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
latency: 0 * time.Millisecond, latency: 0 * time.Millisecond,
@ -219,7 +219,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -236,7 +236,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "current route with similar score and similar but slightly worse latency should not change", name: "current route with similar score and similar but slightly worse latency should not change",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -250,7 +250,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -267,7 +267,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "current route with bad score should be changed to route with better score", name: "current route with bad score should be changed to route with better score",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -281,7 +281,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,
@ -298,7 +298,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
{ {
name: "current chosen route doesn't exist anymore", name: "current chosen route doesn't exist anymore",
statuses: map[string]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
"route1": { "route1": {
connected: true, connected: true,
relayed: false, relayed: false,
@ -312,7 +312,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
latency: 10 * time.Millisecond, latency: 10 * time.Millisecond,
}, },
}, },
existingRoutes: map[string]*route.Route{ existingRoutes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Metric: route.MaxMetric, Metric: route.MaxMetric,

View File

@ -29,8 +29,8 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(map[string][]*route.Route) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
@ -43,7 +43,7 @@ type DefaultManager struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
clientNetworks map[string]*clientNetwork clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter serverRouter serverRouter serverRouter
statusRecorder *peer.Status statusRecorder *peer.Status
@ -57,7 +57,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(), routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
@ -122,7 +122,7 @@ func (m *DefaultManager) Stop() {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not updating routes as context is closed") log.Infof("not updating routes as context is closed")
@ -164,12 +164,12 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
} }
// GetClientRoutes returns the client routes // GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork { func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
return m.clientNetworks return m.clientNetworks
} }
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) { func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
@ -190,7 +190,7 @@ func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
} }
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list // stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) { func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok { if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id) log.Debugf("Stopping client network watcher, %s", id)
@ -200,7 +200,7 @@ func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route)
} }
} }
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
// removing routes that do not exist as per the update from the Management service. // removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)
@ -219,15 +219,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
} }
} }
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
newClientRoutesIDMap := make(map[string][]*route.Route) newClientRoutesIDMap := make(route.HAMap)
newServerRoutesMap := make(map[string]*route.Route) newServerRoutesMap := make(map[route.ID]*route.Route)
ownNetworkIDs := make(map[string]bool) ownNetworkIDs := make(map[route.HAUniqueID]bool)
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute) haID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey { if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true ownNetworkIDs[haID] = true
// only linux is supported for now // only linux is supported for now
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
@ -238,12 +238,12 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute) haID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] { if !ownNetworkIDs[haID] {
if !isPrefixSupported(newRoute.Network) { if !isPrefixSupported(newRoute.Network) {
continue continue
} }
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
} }
} }

View File

@ -14,8 +14,8 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelectionFunc func(map[string][]*route.Route) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func() StopFunc func()
} }
@ -30,14 +30,14 @@ func (m *MockManager) InitialRouteRange() []string {
} }
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes) return m.UpdateRoutesFunc(updateSerial, newRoutes)
} }
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
} }
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) { func (m *MockManager) TriggerSelection(networks route.HAMap) {
if m.TriggerSelectionFunc != nil { if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks) m.TriggerSelectionFunc(networks)
} }

View File

@ -36,7 +36,7 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRangers = nets n.initialRouteRangers = nets
} }
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) { func (n *notifier) onNewRoutes(idMap route.HAMap) {
newNets := make([]string, 0) newNets := make([]string, 0)
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {

View File

@ -3,7 +3,7 @@ package routemanager
import "github.com/netbirdio/netbird/route" import "github.com/netbirdio/netbird/route"
type serverRouter interface { type serverRouter interface {
updateRoutes(map[string]*route.Route) error updateRoutes(map[route.ID]*route.Route) error
removeFromServerNetwork(*route.Route) error removeFromServerNetwork(*route.Route) error
cleanUp() cleanUp()
} }

View File

@ -19,7 +19,7 @@ import (
type defaultServerRouter struct { type defaultServerRouter struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[string]*route.Route routes map[route.ID]*route.Route
firewall firewall.Manager firewall firewall.Manager
wgInterface *iface.WGIface wgInterface *iface.WGIface
statusRecorder *peer.Status statusRecorder *peer.Status
@ -28,15 +28,15 @@ type defaultServerRouter struct {
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{ return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[string]*route.Route), routes: make(map[route.ID]*route.Route),
firewall: firewall, firewall: firewall,
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
}, nil }, nil
} }
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error { func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
serverRoutesToRemove := make([]string, 0) serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes { for routeID := range m.routes {
update, found := routesMap[routeID] update, found := routesMap[routeID]
@ -168,7 +168,7 @@ func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair,
return firewall.RouterPair{}, err return firewall.RouterPair{}, err
} }
return firewall.RouterPair{ return firewall.RouterPair{
ID: route.ID, ID: string(route.ID),
Source: parsed.String(), Source: parsed.String(),
Destination: route.Network.Masked().String(), Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade, Masquerade: route.Masquerade,

View File

@ -12,22 +12,22 @@ import (
) )
type RouteSelector struct { type RouteSelector struct {
selectedRoutes map[string]struct{} selectedRoutes map[route.NetID]struct{}
selectAll bool selectAll bool
} }
func NewRouteSelector() *RouteSelector { func NewRouteSelector() *RouteSelector {
return &RouteSelector{ return &RouteSelector{
selectedRoutes: map[string]struct{}{}, selectedRoutes: map[route.NetID]struct{}{},
// default selects all routes // default selects all routes
selectAll: true, selectAll: true,
} }
} }
// SelectRoutes updates the selected routes based on the provided route IDs. // SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error { func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
if !appendRoute { if !appendRoute {
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
var multiErr *multierror.Error var multiErr *multierror.Error
@ -51,15 +51,15 @@ func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRout
// SelectAllRoutes sets the selector to select all routes. // SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() { func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode. // If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error { func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
if rs.selectAll { if rs.selectAll {
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
for _, route := range allRoutes { for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{} rs.selectedRoutes[route] = struct{}{}
} }
@ -85,11 +85,11 @@ func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) err
// DeselectAllRoutes deselects all routes, effectively disabling route selection. // DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() { func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool { func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
if rs.selectAll { if rs.selectAll {
return true return true
} }
@ -98,18 +98,14 @@ func (rs *RouteSelector) IsSelected(routeID string) bool {
} }
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route { func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
if rs.selectAll { if rs.selectAll {
return maps.Clone(routes) return maps.Clone(routes)
} }
filtered := map[string][]*route.Route{} filtered := route.HAMap{}
for id, rt := range routes { for id, rt := range routes {
netID := id if rs.IsSelected(id.NetID()) {
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt filtered[id] = rt
} }
} }

View File

@ -12,53 +12,53 @@ import (
) )
func TestRouteSelector_SelectRoutes(t *testing.T) { func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
selectRoutes []string selectRoutes []route.NetID
append bool append bool
wantSelected []string wantSelected []route.NetID
wantError bool wantError bool
}{ }{
{ {
name: "Select specific routes, initial all selected", name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"}, selectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{"route1", "route2"}, wantSelected: []route.NetID{"route1", "route2"},
}, },
{ {
name: "Select specific routes, initial all deselected", name: "Select specific routes, initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
selectRoutes: []string{"route1", "route2"}, selectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{"route1", "route2"}, wantSelected: []route.NetID{"route1", "route2"},
}, },
{ {
name: "Select specific routes with initial selection", name: "Select specific routes with initial selection",
initialSelected: []string{"route1"}, initialSelected: []route.NetID{"route1"},
selectRoutes: []string{"route2", "route3"}, selectRoutes: []route.NetID{"route2", "route3"},
wantSelected: []string{"route2", "route3"}, wantSelected: []route.NetID{"route2", "route3"},
}, },
{ {
name: "Select non-existing route", name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"}, selectRoutes: []route.NetID{"route1", "route4"},
wantSelected: []string{"route1"}, wantSelected: []route.NetID{"route1"},
wantError: true, wantError: true,
}, },
{ {
name: "Append route with initial selection", name: "Append route with initial selection",
initialSelected: []string{"route1"}, initialSelected: []route.NetID{"route1"},
selectRoutes: []string{"route2"}, selectRoutes: []route.NetID{"route2"},
append: true, append: true,
wantSelected: []string{"route1", "route2"}, wantSelected: []route.NetID{"route1", "route2"},
}, },
{ {
name: "Append route without initial selection", name: "Append route without initial selection",
selectRoutes: []string{"route2"}, selectRoutes: []route.NetID{"route2"},
append: true, append: true,
wantSelected: []string{"route2"}, wantSelected: []route.NetID{"route2"},
}, },
} }
@ -86,32 +86,32 @@ func TestRouteSelector_SelectRoutes(t *testing.T) {
} }
func TestRouteSelector_SelectAllRoutes(t *testing.T) { func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
wantSelected []string wantSelected []route.NetID
}{ }{
{ {
name: "Initial all selected", name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
{ {
name: "Initial all deselected", name: "Initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
{ {
name: "Initial some selected", name: "Initial some selected",
initialSelected: []string{"route1"}, initialSelected: []route.NetID{"route1"},
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
{ {
name: "Initial all selected", name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"}, initialSelected: []route.NetID{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"}, wantSelected: []route.NetID{"route1", "route2", "route3"},
}, },
} }
@ -134,39 +134,39 @@ func TestRouteSelector_SelectAllRoutes(t *testing.T) {
} }
func TestRouteSelector_DeselectRoutes(t *testing.T) { func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
deselectRoutes []string deselectRoutes []route.NetID
wantSelected []string wantSelected []route.NetID
wantError bool wantError bool
}{ }{
{ {
name: "Deselect specific routes, initial all selected", name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"}, deselectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{"route3"}, wantSelected: []route.NetID{"route3"},
}, },
{ {
name: "Deselect specific routes, initial all deselected", name: "Deselect specific routes, initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
deselectRoutes: []string{"route1", "route2"}, deselectRoutes: []route.NetID{"route1", "route2"},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Deselect specific routes with initial selection", name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"}, initialSelected: []route.NetID{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"}, deselectRoutes: []route.NetID{"route1", "route3"},
wantSelected: []string{"route2"}, wantSelected: []route.NetID{"route2"},
}, },
{ {
name: "Deselect non-existing route", name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"}, initialSelected: []route.NetID{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"}, deselectRoutes: []route.NetID{"route1", "route4"},
wantSelected: []string{"route2"}, wantSelected: []route.NetID{"route2"},
wantError: true, wantError: true,
}, },
} }
@ -195,32 +195,32 @@ func TestRouteSelector_DeselectRoutes(t *testing.T) {
} }
func TestRouteSelector_DeselectAll(t *testing.T) { func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"} allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct { tests := []struct {
name string name string
initialSelected []string initialSelected []route.NetID
wantSelected []string wantSelected []route.NetID
}{ }{
{ {
name: "Initial all selected", name: "Initial all selected",
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Initial all deselected", name: "Initial all deselected",
initialSelected: []string{}, initialSelected: []route.NetID{},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Initial some selected", name: "Initial some selected",
initialSelected: []string{"route1", "route2"}, initialSelected: []route.NetID{"route1", "route2"},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
{ {
name: "Initial all selected", name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"}, initialSelected: []route.NetID{"route1", "route2", "route3"},
wantSelected: []string{}, wantSelected: []route.NetID{},
}, },
} }
@ -245,7 +245,7 @@ func TestRouteSelector_DeselectAll(t *testing.T) {
func TestRouteSelector_IsSelected(t *testing.T) { func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector() rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
require.NoError(t, err) require.NoError(t, err)
assert.True(t, rs.IsSelected("route1")) assert.True(t, rs.IsSelected("route1"))
@ -257,10 +257,10 @@ func TestRouteSelector_IsSelected(t *testing.T) {
func TestRouteSelector_FilterSelected(t *testing.T) { func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector() rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"}) err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
require.NoError(t, err) require.NoError(t, err)
routes := map[string][]*route.Route{ routes := route.HAMap{
"route1-10.0.0.0/8": {}, "route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {}, "route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {}, "route3-172.16.0.0/12": {},
@ -268,7 +268,7 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
filtered := rs.FilterSelected(routes) filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{ assert.Equal(t, route.HAMap{
"route1-10.0.0.0/8": {}, "route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {}, "route2-192.168.0.0/16": {},
}, filtered) }, filtered)

View File

@ -9,10 +9,11 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
) )
type selectRoute struct { type selectRoute struct {
NetID string NetID route.NetID
Network netip.Prefix Network netip.Prefix
Selected bool Selected bool
} }
@ -60,7 +61,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
var pbRoutes []*proto.Route var pbRoutes []*proto.Route
for _, route := range routes { for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{ pbRoutes = append(pbRoutes, &proto.Route{
ID: route.NetID, ID: string(route.NetID),
Network: route.Network.String(), Network: route.Network.String(),
Selected: route.Selected, Selected: route.Selected,
}) })
@ -81,7 +82,8 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
if req.GetAll() { if req.GetAll() {
routeSelector.SelectAllRoutes() routeSelector.SelectAllRoutes()
} else { } else {
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err) return nil, fmt.Errorf("select routes: %w", err)
} }
} }
@ -100,7 +102,8 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
if req.GetAll() { if req.GetAll() {
routeSelector.DeselectAllRoutes() routeSelector.DeselectAllRoutes()
} else { } else {
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.DeselectRoutes(routes, maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err) return nil, fmt.Errorf("deselect routes: %w", err)
} }
} }
@ -108,3 +111,11 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
return &proto.SelectRoutesResponse{}, nil return &proto.SelectRoutesResponse{}, nil
} }
func toNetIDs(routes []string) []route.NetID {
var netIDs []route.NetID
for _, rt := range routes {
netIDs = append(netIDs, route.NetID(rt))
}
return netIDs
}

View File

@ -100,10 +100,10 @@ type AccountManager interface {
SavePolicy(accountID, userID string, policy *Policy) error SavePolicy(accountID, userID string, policy *Policy) error
DeletePolicy(accountID, policyID, userID string) error DeletePolicy(accountID, policyID, userID string) error
ListPolicies(accountID, userID string) ([]*Policy, error) ListPolicies(accountID, userID string) ([]*Policy, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error SaveRoute(accountID, userID string, route *route.Route) error
DeleteRoute(accountID, routeID, userID string) error DeleteRoute(accountID string, routeID route.ID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
@ -229,7 +229,7 @@ type Account struct {
Groups map[string]*nbgroup.Group `gorm:"-"` Groups map[string]*nbgroup.Group `gorm:"-"`
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*route.Route `gorm:"-"` Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
@ -266,7 +266,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(lookupMap) peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) { for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{} peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
} }
groupListMap := a.getPeerGroups(peerID) groupListMap := a.getPeerGroups(peerID)
@ -284,7 +284,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route var filteredRoutes []*route.Route
for _, r := range routes { for _, r := range routes {
_, found := peerMemberships[route.GetHAUniqueID(r)] _, found := peerMemberships[string(route.GetHAUniqueID(r))]
if !found { if !found {
filteredRoutes = append(filteredRoutes, r) filteredRoutes = append(filteredRoutes, r)
} }
@ -323,7 +323,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
return enabledRoutes, disabledRoutes return enabledRoutes, disabledRoutes
} }
seenRoute := make(map[string]struct{}) seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) { takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok { if _, ok := seenRoute[r.ID]; ok {
@ -354,7 +354,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
newPeerRoute := r.Copy() newPeerRoute := r.Copy()
newPeerRoute.Peer = id newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil newPeerRoute.PeerGroups = nil
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id) takeRoute(newPeerRoute, id)
break break
} }
@ -693,7 +693,7 @@ func (a *Account) Copy() *Account {
policies = append(policies, policy.Copy()) policies = append(policies, policy.Copy())
} }
routes := map[string]*route.Route{} routes := map[route.ID]*route.Route{}
for id, r := range a.Routes { for id, r := range a.Routes {
routes[id] = r.Copy() routes[id] = r.Copy()
} }
@ -1946,7 +1946,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
network := NewNetwork() network := NewNetwork()
peers := make(map[string]*nbpeer.Peer) peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*User) users := make(map[string]*User)
routes := make(map[string]*route.Route) routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup) nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID) users[userID] = NewOwnerUser(userID)

View File

@ -1408,7 +1408,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
account := &Account{ account := &Account{
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"route-1": { "route-1": {
ID: "route-1", ID: "route-1",
Network: prefix, Network: prefix,
@ -1437,12 +1437,12 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
routes := account.GetRoutesByPrefix(prefix) routes := account.GetRoutesByPrefix(prefix)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2) routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes { for _, r := range routes {
routeIDs[r.ID] = struct{}{} routeIDs[r.ID] = struct{}{}
} }
assert.Contains(t, routeIDs, "route-1") assert.Contains(t, routeIDs, route.ID("route-1"))
assert.Contains(t, routeIDs, "route-2") assert.Contains(t, routeIDs, route.ID("route-2"))
} }
func TestAccount_GetRoutesToSync(t *testing.T) { func TestAccount_GetRoutesToSync(t *testing.T) {
@ -1459,7 +1459,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
}, },
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"route-1": { "route-1": {
ID: "route-1", ID: "route-1",
Network: prefix, Network: prefix,
@ -1502,12 +1502,12 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2) routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes { for _, r := range routes {
routeIDs[r.ID] = struct{}{} routeIDs[r.ID] = struct{}{}
} }
assert.Contains(t, routeIDs, "route-2") assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, "route-3") assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
@ -1573,7 +1573,7 @@ func TestAccount_Copy(t *testing.T) {
SourcePostureChecks: make([]string, 0), SourcePostureChecks: make([]string, 0),
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
PeerGroups: []string{}, PeerGroups: []string{},

View File

@ -242,7 +242,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
for _, r := range account.Routes { for _, r := range account.Routes {
for _, g := range r.Groups { for _, g := range r.Groups {
if g == groupID { if g == groupID {
return &GroupLinkError{"route", r.NetID} return &GroupLinkError{"route", string(r.NetID)}
} }
} }
} }

View File

@ -107,7 +107,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
newRoute, err := h.accountManager.CreateRoute( newRoute, err := h.accountManager.CreateRoute(
account.Id, newPrefix.String(), peerId, peerGroupIds, account.Id, newPrefix.String(), peerId, peerGroupIds,
req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
) )
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -135,7 +135,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) _, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -185,9 +185,9 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
} }
newRoute := &route.Route{ newRoute := &route.Route{
ID: routeID, ID: route.ID(routeID),
Network: newPrefix, Network: newPrefix,
NetID: req.NetworkId, NetID: route.NetID(req.NetworkId),
NetworkType: prefixType, NetworkType: prefixType,
Masquerade: req.Masquerade, Masquerade: req.Masquerade,
Metric: req.Metric, Metric: req.Metric,
@ -230,7 +230,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteRoute(account.Id, routeID, user.Id) err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -254,7 +254,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id) foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.NotFound, "route not found"), w) util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return return
@ -265,9 +265,9 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
func toRouteResponse(serverRoute *route.Route) *api.Route { func toRouteResponse(serverRoute *route.Route) *api.Route {
route := &api.Route{ route := &api.Route{
Id: serverRoute.ID, Id: string(serverRoute.ID),
Description: serverRoute.Description, Description: serverRoute.Description,
NetworkId: serverRoute.NetID, NetworkId: string(serverRoute.NetID),
Enabled: serverRoute.Enabled, Enabled: serverRoute.Enabled,
Peer: &serverRoute.Peer, Peer: &serverRoute.Peer,
Network: serverRoute.Network.String(), Network: serverRoute.Network.String(),

View File

@ -82,7 +82,7 @@ var testingAccount = &server.Account{
func initRoutesTestData() *RoutesHandler { func initRoutesTestData() *RoutesHandler {
return &RoutesHandler{ return &RoutesHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) { GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }
@ -93,7 +93,7 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}, },
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
if peerID == notFoundPeerID { if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@ -120,7 +120,7 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil return nil
}, },
DeleteRouteFunc: func(_ string, routeID string, _ string) error { DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
if routeID != existingRouteID { if routeID != existingRouteID {
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID) return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
} }

View File

@ -67,7 +67,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
SourcePostureChecks: []string{"1"}, SourcePostureChecks: []string{"1"},
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"1": { "1": {
ID: "1", ID: "1",
PeerGroups: make([]string, 1), PeerGroups: make([]string, 1),
@ -151,7 +151,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
}, },
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[route.ID]*route.Route{
"1": { "1": {
ID: "1", ID: "1",
PeerGroups: make([]string, 1), PeerGroups: make([]string, 1),

View File

@ -51,10 +51,10 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error SaveRouteFunc func(accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
@ -399,15 +399,15 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
} }
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
if am.CreateRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID) return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }
// GetRoute mock implementation of GetRoute from server.AccountManager interface // GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil { if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID, userID) return am.GetRouteFunc(accountID, routeID, userID)
} }
@ -415,7 +415,7 @@ func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*rout
} }
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface // SaveRoute mock implementation of SaveRoute from server.AccountManager interface
func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.Route) error { func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
if am.SaveRouteFunc != nil { if am.SaveRouteFunc != nil {
return am.SaveRouteFunc(accountID, userID, route) return am.SaveRouteFunc(accountID, userID, route)
} }
@ -423,7 +423,7 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
} }
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
if am.DeleteRouteFunc != nil { if am.DeleteRouteFunc != nil {
return am.DeleteRouteFunc(accountID, routeID, userID) return am.DeleteRouteFunc(accountID, routeID, userID)
} }

View File

@ -13,7 +13,7 @@ import (
) )
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -40,7 +40,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
} }
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error { func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefix(prefix) routesWithPrefix := account.GetRoutesByPrefix(prefix)
@ -56,7 +56,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[prefixRoute.ID] = true seenPeers[string(prefixRoute.ID)] = true
} }
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
} }
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -131,7 +131,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
} }
var newRoute route.Route var newRoute route.Route
newRoute.ID = xid.New().String() newRoute.ID = route.ID(xid.New().String())
prefixType, newPrefix, err := route.ParseNetwork(network) prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil { if err != nil {
@ -154,7 +154,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" { if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
@ -175,7 +175,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
newRoute.Groups = groups newRoute.Groups = groups
if account.Routes == nil { if account.Routes == nil {
account.Routes = make(map[string]*route.Route) account.Routes = make(map[route.ID]*route.Route)
} }
account.Routes[newRoute.ID] = &newRoute account.Routes[newRoute.ID] = &newRoute
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
am.updateAccountPeers(account) am.updateAccountPeers(account)
am.StoreEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil return &newRoute, nil
} }
@ -209,7 +209,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" { if utf8.RuneCountInString(string(routeToSave.NetID)) > route.MaxNetIDChar || routeToSave.NetID == "" {
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
@ -248,13 +248,13 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
am.updateAccountPeers(account) am.updateAccountPeers(account)
am.StoreEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil return nil
} }
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -274,7 +274,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
return err return err
} }
am.StoreEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
am.updateAccountPeers(account) am.updateAccountPeers(account)
@ -310,8 +310,8 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{ return &proto.Route{
ID: route.ID, ID: string(route.ID),
NetID: route.NetID, NetID: string(route.NetID),
Network: route.Network.String(), Network: route.Network.String(),
NetworkType: int64(route.NetworkType), NetworkType: int64(route.NetworkType),
Peer: route.Peer, Peer: route.Peer,

View File

@ -40,7 +40,7 @@ const (
func TestCreateRoute(t *testing.T) { func TestCreateRoute(t *testing.T) {
type input struct { type input struct {
network string network string
netID string netID route.NetID
peerKey string peerKey string
peerGroupIDs []string peerGroupIDs []string
description string description string
@ -382,8 +382,8 @@ func TestSaveRoute(t *testing.T) {
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34") invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
validMetric := 1000 validMetric := 1000
invalidMetric := 99999 invalidMetric := 99999
validNetID := "12345678901234567890qw" validNetID := route.NetID("12345678901234567890qw")
invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1" invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
validGroupHA1 := routeGroupHA1 validGroupHA1 := routeGroupHA1
validGroupHA2 := routeGroupHA2 validGroupHA2 := routeGroupHA2

View File

@ -451,7 +451,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
} }
account.GroupsG = nil account.GroupsG = nil
account.Routes = make(map[string]*route.Route, len(account.RoutesG)) account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG { for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy() account.Routes[route.ID] = route.Copy()
} }

View File

@ -2,8 +2,6 @@ package server
import ( import (
"fmt" "fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -12,6 +10,9 @@ import (
"testing" "testing"
"time" "time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -75,9 +76,9 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
} }
account.Users[user.Id] = user account.Users[user.Id] = user
route := &route2.Route{ route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n), ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
Description: "base route", Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n), NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"), Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network, NetworkType: route2.IPv4Network,
Metric: 9999, Metric: 9999,

22
route/hauniqueid.go Normal file
View File

@ -0,0 +1,22 @@
package route
import "strings"
type HAUniqueID string
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
func GetHAUniqueID(input *Route) HAUniqueID {
return HAUniqueID(string(input.NetID) + "-" + input.Network.String())
}
func (id HAUniqueID) String() string {
return string(id)
}
// NetID returns the Network ID from the HAUniqueID
func (id HAUniqueID) NetID() NetID {
if i := strings.LastIndex(string(id), "-"); i != -1 {
return NetID(id[:i])
}
return NetID(id)
}

View File

@ -36,6 +36,12 @@ const (
IPv6Network IPv6Network
) )
type ID string
type NetID string
type HAMap map[HAUniqueID][]*Route
// NetworkType route network type // NetworkType route network type
type NetworkType int type NetworkType int
@ -65,11 +71,11 @@ func ToPrefixType(prefix string) NetworkType {
// Route represents a route // Route represents a route
type Route struct { type Route struct {
ID string `gorm:"primaryKey"` ID ID `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs // AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Network netip.Prefix `gorm:"serializer:json"` Network netip.Prefix `gorm:"serializer:json"`
NetID string NetID NetID
Description string Description string
Peer string Peer string
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
@ -165,8 +171,3 @@ func compareList(list, other []string) bool {
return true return true
} }
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
func GetHAUniqueID(input *Route) string {
return input.NetID + "-" + input.Network.String()
}