Refactor Route IDs (#1891)

This commit is contained in:
Viktor Liu
2024-05-06 14:47:49 +02:00
committed by GitHub
parent 6a4935139d
commit 4e7c17756c
25 changed files with 320 additions and 292 deletions

View File

@ -100,10 +100,10 @@ type AccountManager interface {
SavePolicy(accountID, userID string, policy *Policy) error
DeletePolicy(accountID, policyID, userID string) error
ListPolicies(accountID, userID string) ([]*Policy, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error)
CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error
DeleteRoute(accountID, routeID, userID string) error
DeleteRoute(accountID string, routeID route.ID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
@ -229,7 +229,7 @@ type Account struct {
Groups map[string]*nbgroup.Group `gorm:"-"`
GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*route.Route `gorm:"-"`
Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
@ -266,7 +266,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{}
peerRoutesMembership[string(route.GetHAUniqueID(r))] = struct{}{}
}
groupListMap := a.getPeerGroups(peerID)
@ -284,7 +284,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[route.GetHAUniqueID(r)]
_, found := peerMemberships[string(route.GetHAUniqueID(r))]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
@ -323,7 +323,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[string]struct{})
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok {
@ -354,7 +354,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id)
break
}
@ -693,7 +693,7 @@ func (a *Account) Copy() *Account {
policies = append(policies, policy.Copy())
}
routes := map[string]*route.Route{}
routes := map[route.ID]*route.Route{}
for id, r := range a.Routes {
routes[id] = r.Copy()
}
@ -1946,7 +1946,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
network := NewNetwork()
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*User)
routes := make(map[string]*route.Route)
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)

View File

@ -1408,7 +1408,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
t.Fatal(err)
}
account := &Account{
Routes: map[string]*route.Route{
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
@ -1437,12 +1437,12 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
routes := account.GetRoutesByPrefix(prefix)
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-1")
assert.Contains(t, routeIDs, "route-2")
assert.Contains(t, routeIDs, route.ID("route-1"))
assert.Contains(t, routeIDs, route.ID("route-2"))
}
func TestAccount_GetRoutesToSync(t *testing.T) {
@ -1459,7 +1459,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[string]*route.Route{
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
@ -1502,12 +1502,12 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2)
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, "route-2")
assert.Contains(t, routeIDs, "route-3")
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
@ -1573,7 +1573,7 @@ func TestAccount_Copy(t *testing.T) {
SourcePostureChecks: make([]string, 0),
},
},
Routes: map[string]*route.Route{
Routes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
PeerGroups: []string{},

View File

@ -242,7 +242,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
for _, r := range account.Routes {
for _, g := range r.Groups {
if g == groupID {
return &GroupLinkError{"route", r.NetID}
return &GroupLinkError{"route", string(r.NetID)}
}
}
}

View File

@ -107,7 +107,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
newRoute, err := h.accountManager.CreateRoute(
account.Id, newPrefix.String(), peerId, peerGroupIds,
req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
)
if err != nil {
util.WriteError(err, w)
@ -135,7 +135,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(err, w)
return
@ -185,9 +185,9 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
}
newRoute := &route.Route{
ID: routeID,
ID: route.ID(routeID),
Network: newPrefix,
NetID: req.NetworkId,
NetID: route.NetID(req.NetworkId),
NetworkType: prefixType,
Masquerade: req.Masquerade,
Metric: req.Metric,
@ -230,7 +230,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteRoute(account.Id, routeID, user.Id)
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(err, w)
return
@ -254,7 +254,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return
@ -265,9 +265,9 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
func toRouteResponse(serverRoute *route.Route) *api.Route {
route := &api.Route{
Id: serverRoute.ID,
Id: string(serverRoute.ID),
Description: serverRoute.Description,
NetworkId: serverRoute.NetID,
NetworkId: string(serverRoute.NetID),
Enabled: serverRoute.Enabled,
Peer: &serverRoute.Peer,
Network: serverRoute.Network.String(),

View File

@ -82,7 +82,7 @@ var testingAccount = &server.Account{
func initRoutesTestData() *RoutesHandler {
return &RoutesHandler{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}
@ -93,7 +93,7 @@ func initRoutesTestData() *RoutesHandler {
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
@ -120,7 +120,7 @@ func initRoutesTestData() *RoutesHandler {
}
return nil
},
DeleteRouteFunc: func(_ string, routeID string, _ string) error {
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
if routeID != existingRouteID {
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
}

View File

@ -67,7 +67,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
SourcePostureChecks: []string{"1"},
},
},
Routes: map[string]*route.Route{
Routes: map[route.ID]*route.Route{
"1": {
ID: "1",
PeerGroups: make([]string, 1),
@ -151,7 +151,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
},
},
},
Routes: map[string]*route.Route{
Routes: map[route.ID]*route.Route{
"1": {
ID: "1",
PeerGroups: make([]string, 1),

View File

@ -22,76 +22,76 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
}
@ -399,15 +399,15 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
func (am *MockAccountManager) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID)
return am.CreateRouteFunc(accountID, prefix, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
// GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID, userID)
}
@ -415,7 +415,7 @@ func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*rout
}
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.Route) error {
func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
if am.SaveRouteFunc != nil {
return am.SaveRouteFunc(accountID, userID, route)
}
@ -423,7 +423,7 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
}
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
if am.DeleteRouteFunc != nil {
return am.DeleteRouteFunc(accountID, routeID, userID)
}

View File

@ -13,7 +13,7 @@ import (
)
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -40,7 +40,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
}
// checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error {
func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefix(prefix)
@ -56,7 +56,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
}
if prefixRoute.Peer != "" {
seenPeers[prefixRoute.ID] = true
seenPeers[string(prefixRoute.ID)] = true
}
for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
}
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -131,7 +131,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
}
var newRoute route.Route
newRoute.ID = xid.New().String()
newRoute.ID = route.ID(xid.New().String())
prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil {
@ -154,7 +154,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
@ -175,7 +175,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
newRoute.Groups = groups
if account.Routes == nil {
account.Routes = make(map[string]*route.Route)
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
am.updateAccountPeers(account)
am.StoreEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil
}
@ -209,7 +209,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
if utf8.RuneCountInString(string(routeToSave.NetID)) > route.MaxNetIDChar || routeToSave.NetID == "" {
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
@ -248,13 +248,13 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
am.updateAccountPeers(account)
am.StoreEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@ -274,7 +274,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
return err
}
am.StoreEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
am.updateAccountPeers(account)
@ -310,8 +310,8 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{
ID: route.ID,
NetID: route.NetID,
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,

View File

@ -40,7 +40,7 @@ const (
func TestCreateRoute(t *testing.T) {
type input struct {
network string
netID string
netID route.NetID
peerKey string
peerGroupIDs []string
description string
@ -382,8 +382,8 @@ func TestSaveRoute(t *testing.T) {
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
validMetric := 1000
invalidMetric := 99999
validNetID := "12345678901234567890qw"
invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1"
validNetID := route.NetID("12345678901234567890qw")
invalidNetID := route.NetID("12345678901234567890qwertyuiopqwertyuiop1")
validGroupHA1 := routeGroupHA1
validGroupHA2 := routeGroupHA2

View File

@ -451,7 +451,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
}
account.GroupsG = nil
account.Routes = make(map[string]*route.Route, len(account.RoutesG))
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}

View File

@ -2,8 +2,6 @@ package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"math/rand"
"net"
"net/netip"
@ -12,6 +10,9 @@ import (
"testing"
"time"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -75,9 +76,9 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,