refactor groups methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-01 16:32:31 +03:00
parent f9ed25f8b1
commit 78e238646c
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 243 additions and 107 deletions

View File

@ -963,7 +963,7 @@ func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented") return status.Errorf(status.Internal, "SaveUsers is not implemented")
} }
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented") return status.Errorf(status.Internal, "SaveGroups is not implemented")
} }
@ -1112,3 +1112,18 @@ func (s *FileStore) SaveDNSSettings(_ context.Context, _ LockingStrength, _ stri
func (s *FileStore) SaveAccountSettings(_ context.Context, _ LockingStrength, _ string, _ *Settings) error { func (s *FileStore) SaveAccountSettings(_ context.Context, _ LockingStrength, _ string, _ *Settings) error {
return status.Errorf(status.Internal, "SaveAccountSettings is not implemented") return status.Errorf(status.Internal, "SaveAccountSettings is not implemented")
} }
func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroup is not implemented")
}
func (s *FileStore) DeleteGroup(_ context.Context, _ LockingStrength, _, _ string) error {
return status.Errorf(status.Internal, "DeleteGroup is not implemented")
}
func (s *FileStore) DeleteGroups(_ context.Context, _ LockingStrength, _ []string, _ string) error {
return status.Errorf(status.Internal, "DeleteGroups is not implemented")
}
func (s *FileStore) GetAccountUsers(_ context.Context, _ LockingStrength, _ string) ([]*User, error) {
return nil, status.Errorf(status.Internal, "GetAccountUsers is not implemented")
}

View File

@ -69,21 +69,24 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
} }
// SaveGroups adds new groups to the account. // SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
} }
var eventsToStore []func() if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to create group")
}
var (
eventsToStore []func()
groupsToSave []*nbgroup.Group
)
for _, newGroup := range newGroups { for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
@ -91,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name) existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID)
if err != nil { if err != nil {
s, ok := status.FromError(err) s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound { if !ok || s.ErrorType != status.NotFound {
@ -109,40 +112,54 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil { if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
} }
oldGroup := account.Groups[newGroup.ID] newGroup.AccountID = accountID
account.Groups[newGroup.ID] = newGroup groupsToSave = append(groupsToSave, newGroup)
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) events := am.prepareGroupEvents(ctx, userID, accountID, newGroup)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
} }
account.Network.IncSerial() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err return fmt.Errorf("failed to increment network serial: %w", err)
} }
am.updateAccountPeers(ctx, account) if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
return fmt.Errorf("failed to save groups: %w", err)
}
return nil
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore { for _, storeEvent := range eventsToStore {
storeEvent() storeEvent()
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
return nil return nil
} }
// prepareGroupEvents prepares a list of event functions to be stored. // prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func() var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if oldGroup != nil { oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
@ -152,12 +169,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range addedPeers { for _, peerID := range addedPeers {
peer := account.Peers[p] peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
@ -168,12 +186,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}) })
} }
for _, p := range removedPeers { for _, peerID := range removedPeers {
peer := account.Peers[p] peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
if peer == nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
} }
peerCopy := peer // copy to avoid closure issues peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
@ -203,85 +222,108 @@ func difference(a, b []string) []string {
} }
// DeleteGroup object of the peers. // DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] if user.AccountID != accountID {
if !ok { return status.Errorf(status.PermissionDenied, "no permission to delete group")
return nil
} }
allGroup, err := account.GetGroupAll() group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
if err != nil { if err != nil {
return err return err
} }
if allGroup.ID == groupID { if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
} }
if err = validateDeleteGroup(account, group, userId); err != nil { if err = am.validateDeleteGroup(ctx, group, userID); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err return err
} }
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
} }
// DeleteGroups deletes groups from an account. // DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
var allErrors error if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to delete groups")
}
var (
allErrors error
groupIDsToDelete []string
deletedGroups []*nbgroup.Group
)
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, ok := account.Groups[groupID] group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
if !ok { if err != nil {
continue continue
} }
if err := validateDeleteGroup(account, group, userId); err != nil { if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue continue
} }
delete(account.Groups, groupID) groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
account.Network.IncSerial() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
for _, g := range deletedGroups { for _, group := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return allErrors return allErrors
@ -371,11 +413,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return nil return nil
} }
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration { if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID] executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if executingUser == nil { if err != nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
} }
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
@ -383,32 +425,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
} }
} }
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.ID, group.AccountID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.ID, group.AccountID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name} return &GroupLinkError{"name server groups", linkedDns.Name}
} }
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.ID, group.AccountID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name} return &GroupLinkError{"policy", linkedPolicy.Name}
} }
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.ID, group.AccountID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name} return &GroupLinkError{"setup key", linkedSetupKey.Name}
} }
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.ID, group.AccountID); isLinked {
return &GroupLinkError{"user", linkedUser.Id} return &GroupLinkError{"user", linkedUser.Id}
} }
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
if account.Settings.Extra != nil { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { if err != nil {
return err
}
if settings.Extra != nil {
if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name} return &GroupLinkError{"integrated validator", group.Name}
} }
} }
@ -417,17 +469,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
} }
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, groupID string, accountID string) (bool, *route.Route) {
routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r return true, r
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, groupID string, accountID string) (bool, *Policy) {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies { for _, policy := range policies {
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@ -439,7 +504,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
} }
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, groupID string, accountID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups { for _, dns := range nameServerGroups {
for _, g := range dns.Groups { for _, g := range dns.Groups {
if g == groupID { if g == groupID {
@ -447,11 +518,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
} }
} }
} }
return false, nil return false, nil
} }
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, groupID string, accountID string) (bool, *SetupKey) {
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys { for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) { if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey return true, setupKey
@ -461,7 +539,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
} }
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, groupID string, accountID string) (bool, *User) {
users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users { for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) { if slices.Contains(user.AutoGroups, groupID) {
return true, user return true, user

View File

@ -378,17 +378,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error Create(&usersToSave).Error
} }
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil return nil
@ -500,6 +489,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil return &user, nil
} }
// GetAccountUsers returns all users associated with the account.
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
return getRecords[User](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID) result := s.db.Find(&groups, accountIDCondition, accountID)
@ -1135,9 +1129,38 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil return &group, nil
} }
// SaveGroup saves a group to the database.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
return saveRecord[nbgroup.Group](s.db.WithContext(ctx), lockStrength, group)
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error {
return deleteRecordByID[nbgroup.Group](s.db.WithContext(ctx), lockStrength, groupID, accountID)
}
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
Where("account_id AND id IN ?", accountID, groupIDs).Delete(&nbgroup.Group{})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) return getRecords[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
} }
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
@ -1159,7 +1182,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
@ -1188,7 +1211,7 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[route.Route](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
@ -1209,7 +1232,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetSetupKeyByID retrieves a setup key by its ID and account ID. // GetSetupKeyByID retrieves a setup key by its ID and account ID.
@ -1231,7 +1254,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
@ -1277,13 +1300,13 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
// GetAccountPeers retrieves peers for an account. // GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID) return getRecords[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
} }
// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user. // GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true) db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true)
return getRecords[*nbpeer.Peer](db, lockStrength, accountID) return getRecords[nbpeer.Peer](db, lockStrength, accountID)
} }
// GetPeerByID retrieves a peer by its ID and account ID. // GetPeerByID retrieves a peer by its ID and account ID.
@ -1292,14 +1315,12 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
} }
// getRecords retrieves records from the database based on the account ID. // getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]*T, error) {
var record []T var record []*T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".") recordType := getRecordType(record)
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
} }
@ -1313,8 +1334,7 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID) First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".") recordType := getRecordType(record)
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType) return nil, status.Errorf(status.NotFound, "%s not found", recordType)
@ -1324,15 +1344,23 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
return &record, nil return &record, nil
} }
// saveRecord saves a record to the database.
func saveRecord[T any](db *gorm.DB, lockStrength LockingStrength, record *T) error {
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(record)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save %s to store: %v", getRecordType(record), result.Error)
}
return nil
}
// deleteRecordByID deletes a record by its ID and account ID from the database. // deleteRecordByID deletes a record by its ID and account ID from the database.
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
var record T var record T
parts := strings.Split(fmt.Sprintf("%T", record), ".") recordType := getRecordType(record)
recordType := parts[len(parts)-1]
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&record, accountAndIDQueryCondition, accountID, recordID)
Delete(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
} }
@ -1343,3 +1371,8 @@ func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID
return nil return nil
} }
func getRecordType(record any) string {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
return parts[len(parts)-1]
}

View File

@ -62,6 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
@ -75,7 +76,10 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)