[management] Refactor group to use store methods (#2867)

* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Bethuel Mmbaga 2024-11-15 20:09:32 +03:00 committed by GitHub
parent d9b691b8a5
commit 12f442439a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 878 additions and 335 deletions

View File

@ -110,7 +110,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@ -2101,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@ -2114,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil {
return status.NewGetAccountError(err)
return err
}
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
}
@ -2401,12 +2405,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
am.updateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}
})
require.NoError(t, err, "failed to save group")
policy := Policy{
Enabled: true,
@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err)
return
}
@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})

View File

@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@ -33,8 +33,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
return nil
@ -45,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@ -54,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
@ -73,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
newGroup.ID = xid.New().String()
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
}
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
})
if err != nil {
return err
}
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if oldGroup != nil {
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
@ -155,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
})
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
})
}
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
for _, peerID := range removedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
})
}
@ -206,42 +210,10 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
}
// DeleteGroups deletes groups from an account.
@ -250,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
})
if err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
for _, group := range deletedGroups {
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
}
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
@ -347,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@ -442,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
@ -450,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
@ -464,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
@ -473,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
@ -482,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
}
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}

View File

@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool {
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@ -0,0 +1,90 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"},
expectedReasons: []string{"group: non-existent-group not found"},
},
{
name: "delete multiple groups with mixed results",

View File

@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a)
}
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
if len(groups) == 0 {
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil {
return false, err
}
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
}
if !found {
return false, nil
}
return nil
})
if err != nil {
return false, err
}
return true, nil

View File

@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil {

View File

@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

View File

@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
@ -271,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
if peerLabelUpdated || requiresPeerUpdates {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return peer, nil
@ -335,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers := isPeerInActiveGroup(account, peerID)
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
if err != nil {
return err
}
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
@ -348,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@ -555,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
@ -598,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
if err != nil {
return nil, nil, nil, err
}
if newGroupsAffectsPeers {
am.updateAccountPeers(ctx, accountID)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
@ -666,7 +674,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
}
@ -685,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if isStatusChanged {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
validPeersMap, err := am.GetValidatedPeers(account)
@ -816,7 +824,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
@ -979,7 +987,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
start := time.Now()
defer func() {
if am.metrics != nil {
@ -987,6 +995,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
}
}()
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
return
}
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
@ -1033,12 +1046,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
}

View File

@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
manager.updateAccountPeers(ctx, account.Id)
}
duration := time.Since(start)

View File

@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups {

View File

@ -453,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil
}
for _, g := range removedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
continue
}
@ -473,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
for _, g := range addedGroups {
group, ok := groups[g]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
continue
}

View File

@ -33,12 +33,13 @@ import (
)
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
}
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
@ -1079,10 +1079,45 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
// GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
var peer *nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
return peer, nil
}
// GetPeersByIDs retrieves peers by their IDs and account ID.
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
}
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return peersMap, nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
}
return nil
}
@ -1103,7 +1138,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
db: tx,
storeEngine: s.storeEngine,
}
}
@ -1155,12 +1191,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
}
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
var group *nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
}
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
var group nbgroup.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently
@ -1172,12 +1218,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
return nil, status.NewGroupNotFoundError(groupName)
}
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
}
return &group, nil
}
@ -1185,7 +1232,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
@ -1203,11 +1250,40 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete group from store")
}
if result.RowsAffected == 0 {
return status.NewGroupNotFoundError(groupID)
}
return nil
}
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete groups from store")
}
return nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)

View File

@ -14,11 +14,10 @@ import (
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route"
@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil {
t.Fatal("failed to get group")
return err
@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, err)
require.Equal(t, "All", group.Name)
require.True(t, group.IsGroupAll())
}
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err)
}
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectedCount int
}{
{
name: "retrieve existing groups by existing IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectedCount: 2,
},
{
name: "empty group IDs list",
groupIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing group IDs",
groupIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing group IDs",
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
require.NoError(t, err)
require.Len(t, groups, tt.expectedCount)
})
}
}
func TestSqlStore_SaveGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &nbgroup.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
require.NoError(t, err)
require.Equal(t, savedGroup, group)
}
func TestSqlStore_SaveGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
groups := []*nbgroup.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
require.NoError(t, err)
}
func TestSqlStore_DeleteGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupID string
expectError bool
}{
{
name: "delete existing group",
groupID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "delete non-existing group",
groupID: "non-existing-group-id",
expectError: true,
},
{
name: "delete with empty group ID",
groupID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
} else {
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
require.Error(t, err)
require.Nil(t, group)
}
})
}
}
func TestSqlStore_DeleteGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
groupIDs []string
expectError bool
}{
{
name: "delete multiple existing groups",
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
expectError: false,
},
{
name: "delete non-existing groups",
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
expectError: false,
},
{
name: "delete with empty group IDs list",
groupIDs: []string{},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
for _, groupID := range tt.groupIDs {
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
require.Error(t, err)
require.Nil(t, group)
}
}
})
}
}
func TestSqlStore_GetPeerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerID string
expectError bool
}{
{
name: "retrieve existing peer",
peerID: "cfefqs706sqkneg59g4g",
expectError: false,
},
{
name: "retrieve non-existing peer",
peerID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty peer ID",
peerID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, peer)
} else {
require.NoError(t, err)
require.NotNil(t, peer)
require.Equal(t, tt.peerID, peer.ID)
}
})
}
}
func TestSqlStore_GetPeersByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
peerIDs []string
expectedCount int
}{
{
name: "retrieve existing peers by existing IDs",
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
expectedCount: 2,
},
{
name: "empty peer IDs list",
peerIDs: []string{},
expectedCount: 0,
},
{
name: "non-existing peer IDs",
peerIDs: []string{"nonexistent1", "nonexistent2"},
expectedCount: 0,
},
{
name: "mixed existing and non-existing peer IDs",
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
expectedCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
})
}
}

View File

@ -3,7 +3,6 @@ package status
import (
"errors"
"fmt"
"time"
)
const (
@ -126,11 +125,6 @@ func NewAdminPermissionError() error {
return Errorf(PermissionDenied, "admin role required to perform this action")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
}
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")
@ -140,3 +134,8 @@ func NewInvalidKeyIDError() error {
func NewGetAccountError(err error) error {
return Errorf(Internal, "error getting account: %s", err)
}
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
}

View File

@ -62,7 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@ -76,6 +76,8 @@ type Store interface {
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@ -90,6 +92,8 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@ -108,7 +112,7 @@ type Store interface {
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string

View File

@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
@ -835,7 +835,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
for _, storeEvent := range eventsToStore {
@ -1132,7 +1132,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
}
@ -1240,7 +1240,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
for targetUserID, meta := range deletedUsersMeta {