From 0a70e4c5d45292223c78427984fb470aaf0a9a40 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 8 Nov 2024 18:39:36 +0300 Subject: [PATCH] Refactor groups to use store methods Signed-off-by: bcmmbaga --- management/server/group.go | 426 ++++++++++++------ management/server/integrated_validator.go | 27 +- management/server/mock_server/account_mock.go | 9 - management/server/sql_store.go | 81 +++- management/server/store.go | 7 +- 5 files changed, 373 insertions(+), 177 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index b2ec88cc0..da4c0fb94 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return status.NewAdminPermissionError() } return nil @@ -49,8 +53,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -58,13 +61,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -78,12 +80,19 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var eventsToStore []func() + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var ( + eventsToStore []func() + groupsToSave []*nbgroup.Group + ) for _, newGroup := range newGroups { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { @@ -91,7 +100,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) + existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) if err != nil { s, ok := status.FromError(err) if !ok || s.ErrorType != status.NotFound { @@ -109,15 +118,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { + if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) + events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) eventsToStore = append(eventsToStore, events...) } @@ -126,30 +135,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user newGroupIDs = append(newGroupIDs, newGroup.ID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) + if err != nil { return err } - if areGroupChangesAffectPeers(account, newGroupIDs) { - am.updateAccountPeers(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { + return fmt.Errorf("failed to save groups: %w", err) + } + return nil + }) + if err != nil { + return err } for _, storeEvent := range eventsToStore { storeEvent() } + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -159,12 +183,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range addedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, @@ -175,12 +200,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue } + peerCopy := peer // copy to avoid closure issues eventsToStore = append(eventsToStore, func() { am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, @@ -210,119 +236,108 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. -func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return nil + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - allGroup, err := account.GetGroupAll() + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - if allGroup.ID == groupID { + if group.Name == "All" { return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") } - if err = validateDeleteGroup(account, group, userId); err != nil { - return err - } - delete(account.Groups, groupID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + if err = am.validateDeleteGroup(ctx, group, userID); err != nil { return err } - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) return nil } // DeleteGroups deletes groups from an account. -// Note: This function does not acquire the global lock. -// It is the caller's responsibility to ensure proper locking is in place before invoking this method. -// -// If an error occurs while deleting a group, the function skips it and continues deleting other groups. -// Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - var allErrors error + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var ( + allErrors error + groupIDsToDelete []string + deletedGroups []*nbgroup.Group + ) - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { continue } - if err := validateDeleteGroup(account, group, userId); err != nil { + if err := am.validateDeleteGroup(ctx, group, userID); err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) continue } - delete(account.Groups, groupID) + groupIDsToDelete = append(groupIDsToDelete, groupID) deletedGroups = append(deletedGroups, group) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { + return fmt.Errorf("failed to delete group: %w", err) + } + return nil + }) + if err != nil { return err } - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } return allErrors } -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil -} - // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - add := true for _, itemID := range group.Peers { if itemID == peerID { @@ -334,13 +349,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr group.Peers = append(group.Peers, peerID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { return err } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) + } + return nil + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil @@ -348,41 +377,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - account.Network.IncSerial() + updated := false for i, itemID := range group.Peers { if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(ctx, account); err != nil { - return err - } + updated = true + break } } - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if !updated { + return nil + } + + updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) + if err != nil { + return err + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) + } + return nil + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { + executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return status.Errorf(status.NotFound, "user not found") } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { @@ -390,32 +433,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil { + if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { return &GroupLinkError{"integrated validator", group.Name} } } @@ -423,6 +476,121 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) return nil } +// isGroupLinkedToRoute checks if a group is linked to any route in the account. +func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { + routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + + for _, r := range routes { + if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + return true, r + } + } + + return false, nil +} + +// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. +func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { + policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + + for _, policy := range policies { + for _, rule := range policy.Rules { + if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { + return true, policy + } + } + } + return false, nil +} + +// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. +func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + + for _, dns := range nameServerGroups { + for _, g := range dns.Groups { + if g == groupID { + return true, dns + } + } + } + + return false, nil +} + +// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. +func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + + for _, setupKey := range setupKeys { + if slices.Contains(setupKey.AutoGroups, groupID) { + return true, setupKey + } + } + return false, nil +} + +// isGroupLinkedToUser checks if a group is linked to any user in the account. +func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { + users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + + for _, user := range users { + if slices.Contains(user.AutoGroups, groupID) { + return true, user + } + } + return false, nil +} + +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + return true, nil + } + if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +} + // isGroupLinkedToRoute checks if a group is linked to any route in the account. func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { for _, r := range routes { @@ -457,26 +625,6 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou return false, nil } -// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { - for _, setupKey := range setupKeys { - if slices.Contains(setupKey.AutoGroups, groupID) { - return true, setupKey - } - } - return false, nil -} - -// isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { - for _, user := range users { - if slices.Contains(user.AutoGroups, groupID) { - return true, user - } - } - return false, nil -} - // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c..0c70b702a 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { - if len(groups) == 0 { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) - if err != nil { - return false, err - } - for _, group := range groups { - var found bool - for _, accountGroup := range accountsGroups { - if accountGroup.ID == group { - found = true - break + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } } - if !found { - return false, nil - } + return nil + }) + if err != nil { + return false, err } return true, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d7139bb2a..aa6a47b15 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error @@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a11370e4f..506142453 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -614,11 +614,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { startTime := time.Now() var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -1240,10 +1240,27 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { startTime := time.Now() - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { if errors.Is(result.Error, context.Canceled) { return status.NewStoreContextCanceledError(time.Since(startTime)) @@ -1336,42 +1353,82 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + log.WithContext(ctx).Errorf("failed to get group from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. - query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) if s.storeEngine == PostgresStoreEngine { query = query.Order("json_array_length(peers::json) DESC") } else { query = query.Order("json_array_length(peers) DESC") } - result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) + result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get group by name from store") } return &group, nil } // SaveGroup saves a group to the store. func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store") } return nil } +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "group not found") + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + } + + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/store.go b/management/server/store.go index 73c9ef6a6..cb3c533dd 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,7 +62,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -75,6 +75,8 @@ type Store interface { GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -89,6 +91,7 @@ type Store interface { AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -107,7 +110,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string