Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-08 18:39:36 +03:00
parent 106fc75936
commit 0a70e4c5d4
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
5 changed files with 373 additions and 177 deletions

View File

@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return status.NewAdminPermissionError()
} }
return nil return nil
@ -49,8 +53,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@ -58,13 +61,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@ -78,12 +80,19 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// Note: This function does not acquire the global lock. // Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
var eventsToStore []func() if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var (
eventsToStore []func()
groupsToSave []*nbgroup.Group
)
for _, newGroup := range newGroups { for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
@ -91,7 +100,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name) existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil { if err != nil {
s, ok := status.FromError(err) s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound { if !ok || s.ErrorType != status.NotFound {
@ -109,15 +118,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil { if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
} }
oldGroup := account.Groups[newGroup.ID] newGroup.AccountID = accountID
account.Groups[newGroup.ID] = newGroup groupsToSave = append(groupsToSave, newGroup)
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) events := am.prepareGroupEvents(ctx, userID, accountID, newGroup)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
} }
@ -126,30 +135,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
newGroupIDs = append(newGroupIDs, newGroup.ID) newGroupIDs = append(newGroupIDs, newGroup.ID)
} }
account.Network.IncSerial() updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, newGroupIDs) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
return fmt.Errorf("failed to save groups: %w", err)
}
return nil
})
if err != nil {
return err
} }
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@ -159,12 +183,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { for _, peerID := range addedPeers {
peer := account.Peers[p] peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
@ -175,12 +200,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
@ -210,119 +236,108 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if user.AccountID != accountID {
if !ok { return status.NewUserNotPartOfAccountError()
return nil
} }
allGroup, err := account.GetGroupAll() group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
if allGroup.ID == groupID { if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
} }
if err = validateDeleteGroup(account, group, userId); err != nil { if err = am.validateDeleteGroup(ctx, group, userID); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
return nil return nil
} }
// DeleteGroups deletes groups from an account. // DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
var allErrors error if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var (
allErrors error
groupIDsToDelete []string
deletedGroups []*nbgroup.Group
)
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, ok := account.Groups[groupID] group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if !ok { if err != nil {
continue continue
} }
if err := validateDeleteGroup(account, group, userId); err != nil { if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue continue
} }
delete(account.Groups, groupID) groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
account.Network.IncSerial() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err return err
} }
for _, g := range deletedGroups { if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
for _, group := range deletedGroups {
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true add := true
for _, itemID := range group.Peers { for _, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
@ -334,13 +349,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
} }
account.Network.IncSerial() updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
am.updateAccountPeers(ctx, account) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
@ -348,41 +377,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] updated := false
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers { for i, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil { updated = true
return err break
}
} }
} }
if areGroupChangesAffectPeers(account, []string{group.ID}) { if !updated {
am.updateAccountPeers(ctx, account) return nil
}
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
@ -390,32 +433,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return err
}
if settings.Extra != nil {
if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name} return &GroupLinkError{"integrated validator", group.Name}
} }
} }
@ -423,6 +476,121 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
return nil return nil
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) {
routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) {
users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
for _, r := range routes { for _, r := range routes {
@ -457,26 +625,6 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers. // anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool { func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {

View File

@ -52,26 +52,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a) return am.Store.SaveAccount(ctx, a)
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groups) == 0 { if len(groupIDs) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(ctx, accountId)
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
}
return nil
})
if err != nil { if err != nil {
return false, err return false, err
} }
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
}
}
if !found {
return false, nil
}
}
return true, nil return true, nil
} }

View File

@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {

View File

@ -614,11 +614,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil return &user, nil
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
startTime := time.Now() startTime := time.Now()
var users []*User var users []*User
result := s.db.Find(&users, accountIDCondition, accountID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -1240,10 +1240,27 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
return nil return nil
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { // GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
var peer *nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
return peer, nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
startTime := time.Now() startTime := time.Now()
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, context.Canceled) { if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime)) return status.NewStoreContextCanceledError(time.Since(startTime))
@ -1336,42 +1353,82 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
} }
// GetGroupByID retrieves a group by ID and account ID. // GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) var group *nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
}
log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
return group, nil
} }
// GetGroupByName retrieves a group by name and account ID. // GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
var group nbgroup.Group var group nbgroup.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently // TODO: This fix is accepted for now, but if we need to handle this more frequently
// we may need to reconsider changing the types. // we may need to reconsider changing the types.
query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
if s.storeEngine == PostgresStoreEngine { if s.storeEngine == PostgresStoreEngine {
query = query.Order("json_array_length(peers::json) DESC") query = query.Order("json_array_length(peers::json) DESC")
} else { } else {
query = query.Order("json_array_length(peers) DESC") query = query.Order("json_array_length(peers) DESC")
} }
result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found") return nil, status.Errorf(status.NotFound, "group not found")
} }
return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
} }
return &group, nil return &group, nil
} }
// SaveGroup saves a group to the store. // SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save group to store")
} }
return nil return nil
} }
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete group from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "group not found")
}
return nil
}
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)

View File

@ -62,7 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@ -75,6 +75,8 @@ type Store interface {
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@ -89,6 +91,7 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@ -107,7 +110,7 @@ type Store interface {
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string