mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Refactor Route IDs (#1891)
This commit is contained in:
parent
6a4935139d
commit
4e7c17756c
@ -112,7 +112,7 @@ type Engine struct {
|
|||||||
TURNs []*stun.URI
|
TURNs []*stun.URI
|
||||||
|
|
||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
clientRoutes map[string][]*route.Route
|
clientRoutes route.HAMap
|
||||||
|
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
@ -736,9 +736,9 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
for _, protoRoute := range protoRoutes {
|
for _, protoRoute := range protoRoutes {
|
||||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: protoRoute.ID,
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
NetID: protoRoute.NetID,
|
NetID: route.NetID(protoRoute.NetID),
|
||||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||||
Peer: protoRoute.Peer,
|
Peer: protoRoute.Peer,
|
||||||
Metric: int(protoRoute.Metric),
|
Metric: int(protoRoute.Metric),
|
||||||
@ -1238,18 +1238,15 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the current routes from the route map
|
// GetClientRoutes returns the current routes from the route map
|
||||||
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
|
func (e *Engine) GetClientRoutes() route.HAMap {
|
||||||
return e.clientRoutes
|
return e.clientRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||||
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
|
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||||
routes := make(map[string][]*route.Route, len(e.clientRoutes))
|
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
||||||
for id, v := range e.clientRoutes {
|
for id, v := range e.clientRoutes {
|
||||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
routes[id.NetID()] = v
|
||||||
id = id[:i]
|
|
||||||
}
|
|
||||||
routes[id] = v
|
|
||||||
}
|
}
|
||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
@ -578,7 +578,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}{}
|
}{}
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
input.inputSerial = updateSerial
|
input.inputSerial = updateSerial
|
||||||
input.inputRoutes = newRoutes
|
input.inputRoutes = newRoutes
|
||||||
return nil, nil, testCase.inputErr
|
return nil, nil, testCase.inputErr
|
||||||
@ -743,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
mockRouteManager := &routemanager.MockManager{
|
mockRouteManager := &routemanager.MockManager{
|
||||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -33,7 +33,7 @@ type clientNetwork struct {
|
|||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
routes map[string]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
routeUpdate chan routesUpdate
|
routeUpdate chan routesUpdate
|
||||||
peerStateUpdate chan struct{}
|
peerStateUpdate chan struct{}
|
||||||
routePeersNotifiers map[string]chan struct{}
|
routePeersNotifiers map[string]chan struct{}
|
||||||
@ -50,7 +50,7 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
|
|||||||
stop: cancel,
|
stop: cancel,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
routes: make(map[string]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
routePeersNotifiers: make(map[string]chan struct{}),
|
routePeersNotifiers: make(map[string]chan struct{}),
|
||||||
routeUpdate: make(chan routesUpdate),
|
routeUpdate: make(chan routesUpdate),
|
||||||
peerStateUpdate: make(chan struct{}),
|
peerStateUpdate: make(chan struct{}),
|
||||||
@ -59,8 +59,8 @@ func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, st
|
|||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||||
routePeerStatuses := make(map[string]routerPeerStatus)
|
routePeerStatuses := make(map[route.ID]routerPeerStatus)
|
||||||
for _, r := range c.routes {
|
for _, r := range c.routes {
|
||||||
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -90,12 +90,12 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
// * Latency: Routes with lower latency are prioritized.
|
// * Latency: Routes with lower latency are prioritized.
|
||||||
//
|
//
|
||||||
// It returns the ID of the selected optimal route.
|
// It returns the ID of the selected optimal route.
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||||
chosen := ""
|
chosen := route.ID("")
|
||||||
chosenScore := float64(0)
|
chosenScore := float64(0)
|
||||||
currScore := float64(0)
|
currScore := float64(0)
|
||||||
|
|
||||||
currID := ""
|
currID := route.ID("")
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
currID = c.chosenRoute.ID
|
currID = c.chosenRoute.ID
|
||||||
}
|
}
|
||||||
@ -295,7 +295,7 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
||||||
updateMap := make(map[string]*route.Route)
|
updateMap := make(map[route.ID]*route.Route)
|
||||||
|
|
||||||
for _, r := range update.routes {
|
for _, r := range update.routes {
|
||||||
updateMap[r.ID] = r
|
updateMap[r.ID] = r
|
||||||
|
@ -12,21 +12,21 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
statuses map[string]routerPeerStatus
|
statuses map[route.ID]routerPeerStatus
|
||||||
expectedRouteID string
|
expectedRouteID route.ID
|
||||||
currentRoute string
|
currentRoute route.ID
|
||||||
existingRoutes map[string]*route.Route
|
existingRoutes map[route.ID]*route.Route
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "one route",
|
name: "one route",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -38,14 +38,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one connected routes with relayed and direct",
|
name: "one connected routes with relayed and direct",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: true,
|
relayed: true,
|
||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -57,14 +57,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one connected routes with relayed and no direct",
|
name: "one connected routes with relayed and no direct",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: true,
|
relayed: true,
|
||||||
direct: false,
|
direct: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -76,14 +76,14 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no connected peers",
|
name: "no connected peers",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: false,
|
connected: false,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
direct: false,
|
direct: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -95,7 +95,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with different metrics",
|
name: "multiple connected peers with different metrics",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@ -107,7 +107,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: 9000,
|
Metric: 9000,
|
||||||
@ -124,7 +124,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with one relayed",
|
name: "multiple connected peers with one relayed",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@ -136,7 +136,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
direct: true,
|
direct: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -153,7 +153,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with one direct",
|
name: "multiple connected peers with one direct",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@ -165,7 +165,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
direct: false,
|
direct: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -182,7 +182,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple connected peers with different latencies",
|
name: "multiple connected peers with different latencies",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
latency: 300 * time.Millisecond,
|
latency: 300 * time.Millisecond,
|
||||||
@ -192,7 +192,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -209,7 +209,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "should ignore routes with latency 0",
|
name: "should ignore routes with latency 0",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
latency: 0 * time.Millisecond,
|
latency: 0 * time.Millisecond,
|
||||||
@ -219,7 +219,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -236,7 +236,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "current route with similar score and similar but slightly worse latency should not change",
|
name: "current route with similar score and similar but slightly worse latency should not change",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@ -250,7 +250,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -267,7 +267,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "current route with bad score should be changed to route with better score",
|
name: "current route with bad score should be changed to route with better score",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@ -281,7 +281,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
@ -298,7 +298,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "current chosen route doesn't exist anymore",
|
name: "current chosen route doesn't exist anymore",
|
||||||
statuses: map[string]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
"route1": {
|
"route1": {
|
||||||
connected: true,
|
connected: true,
|
||||||
relayed: false,
|
relayed: false,
|
||||||
@ -312,7 +312,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
latency: 10 * time.Millisecond,
|
latency: 10 * time.Millisecond,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
existingRoutes: map[string]*route.Route{
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
Metric: route.MaxMetric,
|
Metric: route.MaxMetric,
|
||||||
|
@ -29,8 +29,8 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
|||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelection(map[string][]*route.Route)
|
TriggerSelection(route.HAMap)
|
||||||
GetRouteSelector() *routeselector.RouteSelector
|
GetRouteSelector() *routeselector.RouteSelector
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@ -43,7 +43,7 @@ type DefaultManager struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
clientNetworks map[string]*clientNetwork
|
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||||
routeSelector *routeselector.RouteSelector
|
routeSelector *routeselector.RouteSelector
|
||||||
serverRouter serverRouter
|
serverRouter serverRouter
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
@ -57,7 +57,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||||
routeSelector: routeselector.NewRouteSelector(),
|
routeSelector: routeselector.NewRouteSelector(),
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
@ -122,7 +122,7 @@ func (m *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("not updating routes as context is closed")
|
log.Infof("not updating routes as context is closed")
|
||||||
@ -164,12 +164,12 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetClientRoutes returns the client routes
|
// GetClientRoutes returns the client routes
|
||||||
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
|
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
|
||||||
return m.clientNetworks
|
return m.clientNetworks
|
||||||
}
|
}
|
||||||
|
|
||||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||||
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
|
func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||||
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
|
func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||||
for id, client := range m.clientNetworks {
|
for id, client := range m.clientNetworks {
|
||||||
if _, ok := networks[id]; !ok {
|
if _, ok := networks[id]; !ok {
|
||||||
log.Debugf("Stopping client network watcher, %s", id)
|
log.Debugf("Stopping client network watcher, %s", id)
|
||||||
@ -200,7 +200,7 @@ func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks route.HAMap) {
|
||||||
// removing routes that do not exist as per the update from the Management service.
|
// removing routes that do not exist as per the update from the Management service.
|
||||||
m.stopObsoleteClients(networks)
|
m.stopObsoleteClients(networks)
|
||||||
|
|
||||||
@ -219,15 +219,15 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
|
||||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
newClientRoutesIDMap := make(route.HAMap)
|
||||||
newServerRoutesMap := make(map[string]*route.Route)
|
newServerRoutesMap := make(map[route.ID]*route.Route)
|
||||||
ownNetworkIDs := make(map[string]bool)
|
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
haID := route.GetHAUniqueID(newRoute)
|
||||||
if newRoute.Peer == m.pubKey {
|
if newRoute.Peer == m.pubKey {
|
||||||
ownNetworkIDs[networkID] = true
|
ownNetworkIDs[haID] = true
|
||||||
// only linux is supported for now
|
// only linux is supported for now
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
|
||||||
@ -238,12 +238,12 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*r
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
haID := route.GetHAUniqueID(newRoute)
|
||||||
if !ownNetworkIDs[networkID] {
|
if !ownNetworkIDs[haID] {
|
||||||
if !isPrefixSupported(newRoute.Network) {
|
if !isPrefixSupported(newRoute.Network) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,8 +14,8 @@ import (
|
|||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
type MockManager struct {
|
type MockManager struct {
|
||||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||||
TriggerSelectionFunc func(map[string][]*route.Route)
|
TriggerSelectionFunc func(haMap route.HAMap)
|
||||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
@ -30,14 +30,14 @@ func (m *MockManager) InitialRouteRange() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
|
||||||
if m.UpdateRoutesFunc != nil {
|
if m.UpdateRoutesFunc != nil {
|
||||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
|
func (m *MockManager) TriggerSelection(networks route.HAMap) {
|
||||||
if m.TriggerSelectionFunc != nil {
|
if m.TriggerSelectionFunc != nil {
|
||||||
m.TriggerSelectionFunc(networks)
|
m.TriggerSelectionFunc(networks)
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
|
|||||||
n.initialRouteRangers = nets
|
n.initialRouteRangers = nets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
|
func (n *notifier) onNewRoutes(idMap route.HAMap) {
|
||||||
newNets := make([]string, 0)
|
newNets := make([]string, 0)
|
||||||
for _, routes := range idMap {
|
for _, routes := range idMap {
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
@ -3,7 +3,7 @@ package routemanager
|
|||||||
import "github.com/netbirdio/netbird/route"
|
import "github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
type serverRouter interface {
|
type serverRouter interface {
|
||||||
updateRoutes(map[string]*route.Route) error
|
updateRoutes(map[route.ID]*route.Route) error
|
||||||
removeFromServerNetwork(*route.Route) error
|
removeFromServerNetwork(*route.Route) error
|
||||||
cleanUp()
|
cleanUp()
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ import (
|
|||||||
type defaultServerRouter struct {
|
type defaultServerRouter struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
routes map[string]*route.Route
|
routes map[route.ID]*route.Route
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
@ -28,15 +28,15 @@ type defaultServerRouter struct {
|
|||||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||||
return &defaultServerRouter{
|
return &defaultServerRouter{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
routes: make(map[string]*route.Route),
|
routes: make(map[route.ID]*route.Route),
|
||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||||
serverRoutesToRemove := make([]string, 0)
|
serverRoutesToRemove := make([]route.ID, 0)
|
||||||
|
|
||||||
for routeID := range m.routes {
|
for routeID := range m.routes {
|
||||||
update, found := routesMap[routeID]
|
update, found := routesMap[routeID]
|
||||||
@ -168,7 +168,7 @@ func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair,
|
|||||||
return firewall.RouterPair{}, err
|
return firewall.RouterPair{}, err
|
||||||
}
|
}
|
||||||
return firewall.RouterPair{
|
return firewall.RouterPair{
|
||||||
ID: route.ID,
|
ID: string(route.ID),
|
||||||
Source: parsed.String(),
|
Source: parsed.String(),
|
||||||
Destination: route.Network.Masked().String(),
|
Destination: route.Network.Masked().String(),
|
||||||
Masquerade: route.Masquerade,
|
Masquerade: route.Masquerade,
|
||||||
|
@ -12,22 +12,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RouteSelector struct {
|
type RouteSelector struct {
|
||||||
selectedRoutes map[string]struct{}
|
selectedRoutes map[route.NetID]struct{}
|
||||||
selectAll bool
|
selectAll bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouteSelector() *RouteSelector {
|
func NewRouteSelector() *RouteSelector {
|
||||||
return &RouteSelector{
|
return &RouteSelector{
|
||||||
selectedRoutes: map[string]struct{}{},
|
selectedRoutes: map[route.NetID]struct{}{},
|
||||||
// default selects all routes
|
// default selects all routes
|
||||||
selectAll: true,
|
selectAll: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||||
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
|
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
|
||||||
if !appendRoute {
|
if !appendRoute {
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var multiErr *multierror.Error
|
var multiErr *multierror.Error
|
||||||
@ -51,15 +51,15 @@ func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRout
|
|||||||
// SelectAllRoutes sets the selector to select all routes.
|
// SelectAllRoutes sets the selector to select all routes.
|
||||||
func (rs *RouteSelector) SelectAllRoutes() {
|
func (rs *RouteSelector) SelectAllRoutes() {
|
||||||
rs.selectAll = true
|
rs.selectAll = true
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeselectRoutes removes specific routes from the selection.
|
// DeselectRoutes removes specific routes from the selection.
|
||||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||||
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
|
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
for _, route := range allRoutes {
|
for _, route := range allRoutes {
|
||||||
rs.selectedRoutes[route] = struct{}{}
|
rs.selectedRoutes[route] = struct{}{}
|
||||||
}
|
}
|
||||||
@ -85,11 +85,11 @@ func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) err
|
|||||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
rs.selectedRoutes = map[string]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSelected checks if a specific route is selected.
|
// IsSelected checks if a specific route is selected.
|
||||||
func (rs *RouteSelector) IsSelected(routeID string) bool {
|
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -98,18 +98,14 @@ func (rs *RouteSelector) IsSelected(routeID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FilterSelected removes unselected routes from the provided map.
|
// FilterSelected removes unselected routes from the provided map.
|
||||||
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
|
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return maps.Clone(routes)
|
return maps.Clone(routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
filtered := map[string][]*route.Route{}
|
filtered := route.HAMap{}
|
||||||
for id, rt := range routes {
|
for id, rt := range routes {
|
||||||
netID := id
|
if rs.IsSelected(id.NetID()) {
|
||||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
|
||||||
netID = id[:i]
|
|
||||||
}
|
|
||||||
if rs.IsSelected(netID) {
|
|
||||||
filtered[id] = rt
|
filtered[id] = rt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,53 +12,53 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRouteSelector_SelectRoutes(t *testing.T) {
|
func TestRouteSelector_SelectRoutes(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
selectRoutes []string
|
selectRoutes []route.NetID
|
||||||
append bool
|
append bool
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
wantError bool
|
wantError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Select specific routes, initial all selected",
|
name: "Select specific routes, initial all selected",
|
||||||
selectRoutes: []string{"route1", "route2"},
|
selectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{"route1", "route2"},
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Select specific routes, initial all deselected",
|
name: "Select specific routes, initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
selectRoutes: []string{"route1", "route2"},
|
selectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{"route1", "route2"},
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Select specific routes with initial selection",
|
name: "Select specific routes with initial selection",
|
||||||
initialSelected: []string{"route1"},
|
initialSelected: []route.NetID{"route1"},
|
||||||
selectRoutes: []string{"route2", "route3"},
|
selectRoutes: []route.NetID{"route2", "route3"},
|
||||||
wantSelected: []string{"route2", "route3"},
|
wantSelected: []route.NetID{"route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Select non-existing route",
|
name: "Select non-existing route",
|
||||||
selectRoutes: []string{"route1", "route4"},
|
selectRoutes: []route.NetID{"route1", "route4"},
|
||||||
wantSelected: []string{"route1"},
|
wantSelected: []route.NetID{"route1"},
|
||||||
wantError: true,
|
wantError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Append route with initial selection",
|
name: "Append route with initial selection",
|
||||||
initialSelected: []string{"route1"},
|
initialSelected: []route.NetID{"route1"},
|
||||||
selectRoutes: []string{"route2"},
|
selectRoutes: []route.NetID{"route2"},
|
||||||
append: true,
|
append: true,
|
||||||
wantSelected: []string{"route1", "route2"},
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Append route without initial selection",
|
name: "Append route without initial selection",
|
||||||
selectRoutes: []string{"route2"},
|
selectRoutes: []route.NetID{"route2"},
|
||||||
append: true,
|
append: true,
|
||||||
wantSelected: []string{"route2"},
|
wantSelected: []route.NetID{"route2"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,32 +86,32 @@ func TestRouteSelector_SelectRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all deselected",
|
name: "Initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial some selected",
|
name: "Initial some selected",
|
||||||
initialSelected: []string{"route1"},
|
initialSelected: []route.NetID{"route1"},
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
initialSelected: []string{"route1", "route2", "route3"},
|
initialSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
wantSelected: []string{"route1", "route2", "route3"},
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,39 +134,39 @@ func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
deselectRoutes []string
|
deselectRoutes []route.NetID
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
wantError bool
|
wantError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Deselect specific routes, initial all selected",
|
name: "Deselect specific routes, initial all selected",
|
||||||
deselectRoutes: []string{"route1", "route2"},
|
deselectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{"route3"},
|
wantSelected: []route.NetID{"route3"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Deselect specific routes, initial all deselected",
|
name: "Deselect specific routes, initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
deselectRoutes: []string{"route1", "route2"},
|
deselectRoutes: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Deselect specific routes with initial selection",
|
name: "Deselect specific routes with initial selection",
|
||||||
initialSelected: []string{"route1", "route2"},
|
initialSelected: []route.NetID{"route1", "route2"},
|
||||||
deselectRoutes: []string{"route1", "route3"},
|
deselectRoutes: []route.NetID{"route1", "route3"},
|
||||||
wantSelected: []string{"route2"},
|
wantSelected: []route.NetID{"route2"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Deselect non-existing route",
|
name: "Deselect non-existing route",
|
||||||
initialSelected: []string{"route1", "route2"},
|
initialSelected: []route.NetID{"route1", "route2"},
|
||||||
deselectRoutes: []string{"route1", "route4"},
|
deselectRoutes: []route.NetID{"route1", "route4"},
|
||||||
wantSelected: []string{"route2"},
|
wantSelected: []route.NetID{"route2"},
|
||||||
wantError: true,
|
wantError: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -195,32 +195,32 @@ func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_DeselectAll(t *testing.T) {
|
func TestRouteSelector_DeselectAll(t *testing.T) {
|
||||||
allRoutes := []string{"route1", "route2", "route3"}
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialSelected []string
|
initialSelected []route.NetID
|
||||||
|
|
||||||
wantSelected []string
|
wantSelected []route.NetID
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all deselected",
|
name: "Initial all deselected",
|
||||||
initialSelected: []string{},
|
initialSelected: []route.NetID{},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial some selected",
|
name: "Initial some selected",
|
||||||
initialSelected: []string{"route1", "route2"},
|
initialSelected: []route.NetID{"route1", "route2"},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Initial all selected",
|
name: "Initial all selected",
|
||||||
initialSelected: []string{"route1", "route2", "route3"},
|
initialSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
wantSelected: []string{},
|
wantSelected: []route.NetID{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ func TestRouteSelector_DeselectAll(t *testing.T) {
|
|||||||
func TestRouteSelector_IsSelected(t *testing.T) {
|
func TestRouteSelector_IsSelected(t *testing.T) {
|
||||||
rs := routeselector.NewRouteSelector()
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, rs.IsSelected("route1"))
|
assert.True(t, rs.IsSelected("route1"))
|
||||||
@ -257,10 +257,10 @@ func TestRouteSelector_IsSelected(t *testing.T) {
|
|||||||
func TestRouteSelector_FilterSelected(t *testing.T) {
|
func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||||
rs := routeselector.NewRouteSelector()
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
err := rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, []route.NetID{"route1", "route2", "route3"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
routes := map[string][]*route.Route{
|
routes := route.HAMap{
|
||||||
"route1-10.0.0.0/8": {},
|
"route1-10.0.0.0/8": {},
|
||||||
"route2-192.168.0.0/16": {},
|
"route2-192.168.0.0/16": {},
|
||||||
"route3-172.16.0.0/12": {},
|
"route3-172.16.0.0/12": {},
|
||||||
@ -268,7 +268,7 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
|
|||||||
|
|
||||||
filtered := rs.FilterSelected(routes)
|
filtered := rs.FilterSelected(routes)
|
||||||
|
|
||||||
assert.Equal(t, map[string][]*route.Route{
|
assert.Equal(t, route.HAMap{
|
||||||
"route1-10.0.0.0/8": {},
|
"route1-10.0.0.0/8": {},
|
||||||
"route2-192.168.0.0/16": {},
|
"route2-192.168.0.0/16": {},
|
||||||
}, filtered)
|
}, filtered)
|
||||||
|
@ -9,10 +9,11 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type selectRoute struct {
|
type selectRoute struct {
|
||||||
NetID string
|
NetID route.NetID
|
||||||
Network netip.Prefix
|
Network netip.Prefix
|
||||||
Selected bool
|
Selected bool
|
||||||
}
|
}
|
||||||
@ -60,7 +61,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
|||||||
var pbRoutes []*proto.Route
|
var pbRoutes []*proto.Route
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
pbRoutes = append(pbRoutes, &proto.Route{
|
pbRoutes = append(pbRoutes, &proto.Route{
|
||||||
ID: route.NetID,
|
ID: string(route.NetID),
|
||||||
Network: route.Network.String(),
|
Network: route.Network.String(),
|
||||||
Selected: route.Selected,
|
Selected: route.Selected,
|
||||||
})
|
})
|
||||||
@ -81,7 +82,8 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
|
|||||||
if req.GetAll() {
|
if req.GetAll() {
|
||||||
routeSelector.SelectAllRoutes()
|
routeSelector.SelectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
if err := routeSelector.SelectRoutes(req.GetRouteIDs(), req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
routes := toNetIDs(req.GetRouteIDs())
|
||||||
|
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
||||||
return nil, fmt.Errorf("select routes: %w", err)
|
return nil, fmt.Errorf("select routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,7 +102,8 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
|
|||||||
if req.GetAll() {
|
if req.GetAll() {
|
||||||
routeSelector.DeselectAllRoutes()
|
routeSelector.DeselectAllRoutes()
|
||||||
} else {
|
} else {
|
||||||
if err := routeSelector.DeselectRoutes(req.GetRouteIDs(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
routes := toNetIDs(req.GetRouteIDs())
|
||||||
|
if err := routeSelector.DeselectRoutes(routes, maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil {
|
||||||
return nil, fmt.Errorf("deselect routes: %w", err)
|
return nil, fmt.Errorf("deselect routes: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,3 +111,11 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
|
|||||||
|
|
||||||
return &proto.SelectRoutesResponse{}, nil
|
return &proto.SelectRoutesResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toNetIDs(routes []string) []route.NetID {
|
||||||
|
var netIDs []route.NetID
|
||||||
|
for _, rt := range routes {
|
||||||
|
netIDs = append(netIDs, route.NetID(rt))
|
||||||
|
}
|
||||||
|
return netIDs
|
||||||
|
}
|
||||||
|
@ -100,10 +100,10 @@ type AccountManager interface {
|
|||||||
SavePolicy(accountID, userID string, policy *Policy) error
|
SavePolicy(accountID, userID string, policy *Policy) error
|
||||||
DeletePolicy(accountID, policyID, userID string) error
|
DeletePolicy(accountID, policyID, userID string) error
|
||||||
ListPolicies(accountID, userID string) ([]*Policy, error)
|
ListPolicies(accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(accountID, routeID, userID string) (*route.Route, error)
|
GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||||
SaveRoute(accountID, userID string, route *route.Route) error
|
SaveRoute(accountID, userID string, route *route.Route) error
|
||||||
DeleteRoute(accountID, routeID, userID string) error
|
DeleteRoute(accountID string, routeID route.ID, userID string) error
|
||||||
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
||||||
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
@ -229,7 +229,7 @@ type Account struct {
|
|||||||
Groups map[string]*nbgroup.Group `gorm:"-"`
|
Groups map[string]*nbgroup.Group `gorm:"-"`
|
||||||
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||||
Routes map[string]*route.Route `gorm:"-"`
|
Routes map[route.ID]*route.Route `gorm:"-"`
|
||||||
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
|
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
|
||||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
@ -266,7 +266,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
|||||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
|
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
|
||||||
peerRoutesMembership := make(lookupMap)
|
peerRoutesMembership := make(lookupMap)
|
||||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||||
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{}
|
peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
groupListMap := a.getPeerGroups(peerID)
|
groupListMap := a.getPeerGroups(peerID)
|
||||||
@ -284,7 +284,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
|
|||||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
|
||||||
var filteredRoutes []*route.Route
|
var filteredRoutes []*route.Route
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
_, found := peerMemberships[route.GetHAUniqueID(r)]
|
_, found := peerMemberships[string(route.GetHAUniqueID(r))]
|
||||||
if !found {
|
if !found {
|
||||||
filteredRoutes = append(filteredRoutes, r)
|
filteredRoutes = append(filteredRoutes, r)
|
||||||
}
|
}
|
||||||
@ -323,7 +323,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
|
|||||||
return enabledRoutes, disabledRoutes
|
return enabledRoutes, disabledRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
seenRoute := make(map[string]struct{})
|
seenRoute := make(map[route.ID]struct{})
|
||||||
|
|
||||||
takeRoute := func(r *route.Route, id string) {
|
takeRoute := func(r *route.Route, id string) {
|
||||||
if _, ok := seenRoute[r.ID]; ok {
|
if _, ok := seenRoute[r.ID]; ok {
|
||||||
@ -354,7 +354,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
|
|||||||
newPeerRoute := r.Copy()
|
newPeerRoute := r.Copy()
|
||||||
newPeerRoute.Peer = id
|
newPeerRoute.Peer = id
|
||||||
newPeerRoute.PeerGroups = nil
|
newPeerRoute.PeerGroups = nil
|
||||||
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map
|
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||||
takeRoute(newPeerRoute, id)
|
takeRoute(newPeerRoute, id)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -693,7 +693,7 @@ func (a *Account) Copy() *Account {
|
|||||||
policies = append(policies, policy.Copy())
|
policies = append(policies, policy.Copy())
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := map[string]*route.Route{}
|
routes := map[route.ID]*route.Route{}
|
||||||
for id, r := range a.Routes {
|
for id, r := range a.Routes {
|
||||||
routes[id] = r.Copy()
|
routes[id] = r.Copy()
|
||||||
}
|
}
|
||||||
@ -1946,7 +1946,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
|||||||
network := NewNetwork()
|
network := NewNetwork()
|
||||||
peers := make(map[string]*nbpeer.Peer)
|
peers := make(map[string]*nbpeer.Peer)
|
||||||
users := make(map[string]*User)
|
users := make(map[string]*User)
|
||||||
routes := make(map[string]*route.Route)
|
routes := make(map[route.ID]*route.Route)
|
||||||
setupKeys := map[string]*SetupKey{}
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userID] = NewOwnerUser(userID)
|
users[userID] = NewOwnerUser(userID)
|
||||||
|
@ -1408,7 +1408,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route-1": {
|
"route-1": {
|
||||||
ID: "route-1",
|
ID: "route-1",
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@ -1437,12 +1437,12 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
|||||||
routes := account.GetRoutesByPrefix(prefix)
|
routes := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeIDs := make(map[string]struct{}, 2)
|
routeIDs := make(map[route.ID]struct{}, 2)
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
routeIDs[r.ID] = struct{}{}
|
routeIDs[r.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
assert.Contains(t, routeIDs, "route-1")
|
assert.Contains(t, routeIDs, route.ID("route-1"))
|
||||||
assert.Contains(t, routeIDs, "route-2")
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||||
@ -1459,7 +1459,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
|||||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||||
},
|
},
|
||||||
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route-1": {
|
"route-1": {
|
||||||
ID: "route-1",
|
ID: "route-1",
|
||||||
Network: prefix,
|
Network: prefix,
|
||||||
@ -1502,12 +1502,12 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
|||||||
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||||
|
|
||||||
assert.Len(t, routes, 2)
|
assert.Len(t, routes, 2)
|
||||||
routeIDs := make(map[string]struct{}, 2)
|
routeIDs := make(map[route.ID]struct{}, 2)
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
routeIDs[r.ID] = struct{}{}
|
routeIDs[r.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
assert.Contains(t, routeIDs, "route-2")
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
assert.Contains(t, routeIDs, "route-3")
|
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||||
|
|
||||||
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||||
|
|
||||||
@ -1573,7 +1573,7 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
SourcePostureChecks: make([]string, 0),
|
SourcePostureChecks: make([]string, 0),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"route1": {
|
"route1": {
|
||||||
ID: "route1",
|
ID: "route1",
|
||||||
PeerGroups: []string{},
|
PeerGroups: []string{},
|
||||||
|
@ -242,7 +242,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
|||||||
for _, r := range account.Routes {
|
for _, r := range account.Routes {
|
||||||
for _, g := range r.Groups {
|
for _, g := range r.Groups {
|
||||||
if g == groupID {
|
if g == groupID {
|
||||||
return &GroupLinkError{"route", r.NetID}
|
return &GroupLinkError{"route", string(r.NetID)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(
|
newRoute, err := h.accountManager.CreateRoute(
|
||||||
account.Id, newPrefix.String(), peerId, peerGroupIds,
|
account.Id, newPrefix.String(), peerId, peerGroupIds,
|
||||||
req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
|
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
@ -135,7 +135,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@ -185,9 +185,9 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRoute := &route.Route{
|
newRoute := &route.Route{
|
||||||
ID: routeID,
|
ID: route.ID(routeID),
|
||||||
Network: newPrefix,
|
Network: newPrefix,
|
||||||
NetID: req.NetworkId,
|
NetID: route.NetID(req.NetworkId),
|
||||||
NetworkType: prefixType,
|
NetworkType: prefixType,
|
||||||
Masquerade: req.Masquerade,
|
Masquerade: req.Masquerade,
|
||||||
Metric: req.Metric,
|
Metric: req.Metric,
|
||||||
@ -230,7 +230,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteRoute(account.Id, routeID, user.Id)
|
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
@ -254,7 +254,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
||||||
return
|
return
|
||||||
@ -265,9 +265,9 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func toRouteResponse(serverRoute *route.Route) *api.Route {
|
func toRouteResponse(serverRoute *route.Route) *api.Route {
|
||||||
route := &api.Route{
|
route := &api.Route{
|
||||||
Id: serverRoute.ID,
|
Id: string(serverRoute.ID),
|
||||||
Description: serverRoute.Description,
|
Description: serverRoute.Description,
|
||||||
NetworkId: serverRoute.NetID,
|
NetworkId: string(serverRoute.NetID),
|
||||||
Enabled: serverRoute.Enabled,
|
Enabled: serverRoute.Enabled,
|
||||||
Peer: &serverRoute.Peer,
|
Peer: &serverRoute.Peer,
|
||||||
Network: serverRoute.Network.String(),
|
Network: serverRoute.Network.String(),
|
||||||
|
@ -82,7 +82,7 @@ var testingAccount = &server.Account{
|
|||||||
func initRoutesTestData() *RoutesHandler {
|
func initRoutesTestData() *RoutesHandler {
|
||||||
return &RoutesHandler{
|
return &RoutesHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
|
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||||
if routeID == existingRouteID {
|
if routeID == existingRouteID {
|
||||||
return baseExistingRoute, nil
|
return baseExistingRoute, nil
|
||||||
}
|
}
|
||||||
@ -93,7 +93,7 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||||
},
|
},
|
||||||
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
|
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
|
||||||
if peerID == notFoundPeerID {
|
if peerID == notFoundPeerID {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
@ -120,7 +120,7 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
DeleteRouteFunc: func(_ string, routeID string, _ string) error {
|
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
|
||||||
if routeID != existingRouteID {
|
if routeID != existingRouteID {
|
||||||
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
|
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
|
|||||||
SourcePostureChecks: []string{"1"},
|
SourcePostureChecks: []string{"1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"1": {
|
"1": {
|
||||||
ID: "1",
|
ID: "1",
|
||||||
PeerGroups: make([]string, 1),
|
PeerGroups: make([]string, 1),
|
||||||
@ -151,7 +151,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Routes: map[string]*route.Route{
|
Routes: map[route.ID]*route.Route{
|
||||||
"1": {
|
"1": {
|
||||||
ID: "1",
|
ID: "1",
|
||||||
PeerGroups: make([]string, 1),
|
PeerGroups: make([]string, 1),
|
||||||
|
@ -22,76 +22,76 @@ type MockAccountManager struct {
|
|||||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||||
MarkPATUsedFunc func(pat string) error
|
MarkPATUsedFunc func(pat string) error
|
||||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
|
||||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
|
||||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||||
DeleteAccountFunc func(accountID, userID string) error
|
DeleteAccountFunc func(accountID, userID string) error
|
||||||
GetDNSDomainFunc func() string
|
GetDNSDomainFunc func() string
|
||||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||||
HasConnectedChannelFunc func(peerID string) bool
|
HasConnectedChannelFunc func(peerID string) bool
|
||||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManagerFunc func() idp.Manager
|
GetIdpManagerFunc func() idp.Manager
|
||||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||||
}
|
}
|
||||||
@ -399,15 +399,15 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||||
if am.CreateRouteFunc != nil {
|
if am.CreateRouteFunc != nil {
|
||||||
return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID)
|
return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoute mock implementation of GetRoute from server.AccountManager interface
|
// GetRoute mock implementation of GetRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
if am.GetRouteFunc != nil {
|
if am.GetRouteFunc != nil {
|
||||||
return am.GetRouteFunc(accountID, routeID, userID)
|
return am.GetRouteFunc(accountID, routeID, userID)
|
||||||
}
|
}
|
||||||
@ -415,7 +415,7 @@ func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*rout
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
|
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.Route) error {
|
func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
|
||||||
if am.SaveRouteFunc != nil {
|
if am.SaveRouteFunc != nil {
|
||||||
return am.SaveRouteFunc(accountID, userID, route)
|
return am.SaveRouteFunc(accountID, userID, route)
|
||||||
}
|
}
|
||||||
@ -423,7 +423,7 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
||||||
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
|
||||||
if am.DeleteRouteFunc != nil {
|
if am.DeleteRouteFunc != nil {
|
||||||
return am.DeleteRouteFunc(accountID, routeID, userID)
|
return am.DeleteRouteFunc(accountID, routeID, userID)
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetRoute gets a route object from account and route IDs
|
// GetRoute gets a route object from account and route IDs
|
||||||
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error {
|
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
|
||||||
// routes can have both peer and peer_groups
|
// routes can have both peer and peer_groups
|
||||||
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
if prefixRoute.Peer != "" {
|
if prefixRoute.Peer != "" {
|
||||||
seenPeers[prefixRoute.ID] = true
|
seenPeers[string(prefixRoute.ID)] = true
|
||||||
}
|
}
|
||||||
for _, groupID := range prefixRoute.PeerGroups {
|
for _, groupID := range prefixRoute.PeerGroups {
|
||||||
seenPeerGroups[groupID] = true
|
seenPeerGroups[groupID] = true
|
||||||
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute creates and saves a new route
|
// CreateRoute creates and saves a new route
|
||||||
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
var newRoute route.Route
|
var newRoute route.Route
|
||||||
newRoute.ID = xid.New().String()
|
newRoute.ID = route.ID(xid.New().String())
|
||||||
|
|
||||||
prefixType, newPrefix, err := route.ParseNetwork(network)
|
prefixType, newPrefix, err := route.ParseNetwork(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -154,7 +154,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
|
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,7 +175,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
newRoute.Groups = groups
|
newRoute.Groups = groups
|
||||||
|
|
||||||
if account.Routes == nil {
|
if account.Routes == nil {
|
||||||
account.Routes = make(map[string]*route.Route)
|
account.Routes = make(map[route.ID]*route.Route)
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Routes[newRoute.ID] = &newRoute
|
account.Routes[newRoute.ID] = &newRoute
|
||||||
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
|||||||
|
|
||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
am.StoreEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
|
|
||||||
return &newRoute, nil
|
return &newRoute, nil
|
||||||
}
|
}
|
||||||
@ -209,7 +209,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
|||||||
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
if utf8.RuneCountInString(string(routeToSave.NetID)) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,13 +248,13 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
|||||||
|
|
||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
am.StoreEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoute deletes route with routeID
|
// DeleteRoute deletes route with routeID
|
||||||
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
|
||||||
unlock := am.Store.AcquireAccountLock(accountID)
|
unlock := am.Store.AcquireAccountLock(accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -274,7 +274,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
am.updateAccountPeers(account)
|
am.updateAccountPeers(account)
|
||||||
|
|
||||||
@ -310,8 +310,8 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
|
|||||||
|
|
||||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||||
return &proto.Route{
|
return &proto.Route{
|
||||||
ID: route.ID,
|
ID: string(route.ID),
|
||||||
NetID: route.NetID,
|
NetID: string(route.NetID),
|
||||||
Network: route.Network.String(),
|
Network: route.Network.String(),
|
||||||
NetworkType: int64(route.NetworkType),
|
NetworkType: int64(route.NetworkType),
|
||||||
Peer: route.Peer,
|
Peer: route.Peer,
|
||||||
|
@ -40,7 +40,7 @@ const (
|
|||||||
func TestCreateRoute(t *testing.T) {
|
func TestCreateRoute(t *testing.T) {
|
||||||
type input struct {
|
type input struct {
|
||||||
network string
|
network string
|
||||||
netID string
|
netID route.NetID
|
||||||
peerKey string
|
peerKey string
|
||||||
peerGroupIDs []string
|
peerGroupIDs []string
|
||||||
description string
|
description string
|
||||||
@ -382,8 +382,8 @@ func TestSaveRoute(t *testing.T) {
|
|||||||
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
|
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
|
||||||
validMetric := 1000
|
validMetric := 1000
|
||||||
invalidMetric := 99999
|
invalidMetric := 99999
|
||||||
validNetID := "12345678901234567890qw"
|
validNetID := route.NetID("12345678901234567890qw")
|
||||||
invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1"
|
invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
|
||||||
validGroupHA1 := routeGroupHA1
|
validGroupHA1 := routeGroupHA1
|
||||||
validGroupHA2 := routeGroupHA2
|
validGroupHA2 := routeGroupHA2
|
||||||
|
|
||||||
|
@ -451,7 +451,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
|||||||
}
|
}
|
||||||
account.GroupsG = nil
|
account.GroupsG = nil
|
||||||
|
|
||||||
account.Routes = make(map[string]*route.Route, len(account.RoutesG))
|
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||||
for _, route := range account.RoutesG {
|
for _, route := range account.RoutesG {
|
||||||
account.Routes[route.ID] = route.Copy()
|
account.Routes[route.ID] = route.Copy()
|
||||||
}
|
}
|
||||||
|
@ -2,8 +2,6 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -12,6 +10,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -75,9 +76,9 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
|
|||||||
}
|
}
|
||||||
account.Users[user.Id] = user
|
account.Users[user.Id] = user
|
||||||
route := &route2.Route{
|
route := &route2.Route{
|
||||||
ID: fmt.Sprintf("network-id-%d", n),
|
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
|
||||||
Description: "base route",
|
Description: "base route",
|
||||||
NetID: fmt.Sprintf("network-id-%d", n),
|
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
|
||||||
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
||||||
NetworkType: route2.IPv4Network,
|
NetworkType: route2.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
|
22
route/hauniqueid.go
Normal file
22
route/hauniqueid.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package route
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type HAUniqueID string
|
||||||
|
|
||||||
|
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
|
||||||
|
func GetHAUniqueID(input *Route) HAUniqueID {
|
||||||
|
return HAUniqueID(string(input.NetID) + "-" + input.Network.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (id HAUniqueID) String() string {
|
||||||
|
return string(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetID returns the Network ID from the HAUniqueID
|
||||||
|
func (id HAUniqueID) NetID() NetID {
|
||||||
|
if i := strings.LastIndex(string(id), "-"); i != -1 {
|
||||||
|
return NetID(id[:i])
|
||||||
|
}
|
||||||
|
return NetID(id)
|
||||||
|
}
|
@ -36,6 +36,12 @@ const (
|
|||||||
IPv6Network
|
IPv6Network
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ID string
|
||||||
|
|
||||||
|
type NetID string
|
||||||
|
|
||||||
|
type HAMap map[HAUniqueID][]*Route
|
||||||
|
|
||||||
// NetworkType route network type
|
// NetworkType route network type
|
||||||
type NetworkType int
|
type NetworkType int
|
||||||
|
|
||||||
@ -65,11 +71,11 @@ func ToPrefixType(prefix string) NetworkType {
|
|||||||
|
|
||||||
// Route represents a route
|
// Route represents a route
|
||||||
type Route struct {
|
type Route struct {
|
||||||
ID string `gorm:"primaryKey"`
|
ID ID `gorm:"primaryKey"`
|
||||||
// AccountID is a reference to Account that this object belongs
|
// AccountID is a reference to Account that this object belongs
|
||||||
AccountID string `gorm:"index"`
|
AccountID string `gorm:"index"`
|
||||||
Network netip.Prefix `gorm:"serializer:json"`
|
Network netip.Prefix `gorm:"serializer:json"`
|
||||||
NetID string
|
NetID NetID
|
||||||
Description string
|
Description string
|
||||||
Peer string
|
Peer string
|
||||||
PeerGroups []string `gorm:"serializer:json"`
|
PeerGroups []string `gorm:"serializer:json"`
|
||||||
@ -165,8 +171,3 @@ func compareList(list, other []string) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHAUniqueID returns a highly available route ID by combining Network ID and Network range address
|
|
||||||
func GetHAUniqueID(input *Route) string {
|
|
||||||
return input.NetID + "-" + input.Network.String()
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user