Remove get account from groups ops

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-16 16:04:34 +03:00
parent 1123729c1c
commit d7c63d5c04
14 changed files with 87 additions and 86 deletions

View File

@ -53,7 +53,10 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerLoginExpiration = 24 * time.Hour
DefaultPeerInactivityExpiration = 10 * time.Minute DefaultPeerInactivityExpiration = 10 * time.Minute
emptyUserID = "empty user ID in claims" emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
errGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
errGetAccountFmt = "failed to get account: %w"
errNetworkSerialIncrementFmt = "failed to increment network serial: %w"
) )
type userLoggedInOnce bool type userLoggedInOnce bool
@ -111,7 +114,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@ -1144,7 +1146,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
var oldSettings *Settings var oldSettings *Settings
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@ -1153,7 +1155,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return fmt.Errorf("failed to validate extra settings: %w", err) return fmt.Errorf("failed to validate extra settings: %w", err)
} }
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { if err = transaction.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
return fmt.Errorf("failed to update account settings: %w", err) return fmt.Errorf("failed to update account settings: %w", err)
} }
@ -2049,7 +2051,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@ -2079,7 +2081,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID) groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@ -2226,7 +2228,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err)
return "", nil, err return "", nil, err
} }
@ -2240,7 +2242,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
// check again if the domain has a primary account because of simultaneous requests // check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err)
return "", nil, err return "", nil, err
} }
@ -2271,7 +2273,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err)
return "", err return "", err
} }

View File

@ -2612,7 +2612,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims) err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")

View File

@ -113,7 +113,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err return err
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err return err
} }
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked {
return status.Errorf(status.PermissionDenied, "groups are blocked for users") return status.Errorf(status.PermissionDenied, "access to groups is blocked for users")
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
} }
return nil return nil
@ -59,7 +63,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, accountID) return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
@ -80,7 +84,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
if user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to create group") return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
} }
var ( var (
@ -126,7 +130,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf(errNetworkSerialIncrementFmt, err)
} }
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
@ -144,7 +148,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account: %w", err) return fmt.Errorf(errGetAccountFmt, err)
} }
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
@ -229,7 +233,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
} }
if user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to delete group") return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
} }
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
@ -247,7 +251,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf(errNetworkSerialIncrementFmt, err)
} }
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); err != nil { if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); err != nil {
@ -263,7 +267,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account: %w", err) return fmt.Errorf(errGetAccountFmt, err)
} }
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
@ -278,7 +282,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
} }
if user.AccountID != accountID { if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to delete groups") return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
} }
var ( var (
@ -304,7 +308,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf(errNetworkSerialIncrementFmt, err)
} }
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
@ -322,46 +326,20 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("error getting account: %w", err) return fmt.Errorf(errGetAccountFmt, err)
} }
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return allErrors return allErrors
} }
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group // GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true add := true
for _, itemID := range group.Peers { for _, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
@ -373,11 +351,24 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
group.Peers = append(group.Peers, peerID) group.Peers = append(group.Peers, peerID)
} }
account.Network.IncSerial() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
@ -385,29 +376,42 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
group, ok := account.Groups[groupID] updated := false
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers { for i, itemID := range group.Peers {
if itemID == peerID { if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil { updated = true
return err break
}
} }
} }
if !updated {
return nil
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil

View File

@ -56,13 +56,15 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
if len(groups) == 0 { if len(groups) == 0 {
return true, nil return true, nil
} }
accountsGroups, err := am.ListGroups(ctx, accountId)
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountId)
if err != nil { if err != nil {
return false, err return false, err
} }
for _, group := range groups { for _, group := range groups {
var found bool var found bool
for _, accountGroup := range accountsGroups { for _, accountGroup := range accountGroups {
if accountGroup.ID == group { if accountGroup.ID == group {
found = true found = true
break break

View File

@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@ -347,14 +346,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {

View File

@ -216,7 +216,7 @@ func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, ac
return err return err
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -743,7 +743,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }

View File

@ -1090,7 +1090,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {

View File

@ -225,7 +225,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
keyDuration = expiresIn keyDuration = expiresIn
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -278,7 +278,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can update setup keys") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can update setup keys")
} }
groups, err := am.Store.GetAccountGroups(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -555,9 +555,9 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
return users, nil return users, nil
} }
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -841,7 +841,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
// SaveAccountSettings stores the account settings in DB. // SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error { func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings}) Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)

View File

@ -1201,7 +1201,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID) users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, users, len(account.Users)) require.Len(t, users, len(account.Users))
} }

View File

@ -72,7 +72,7 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error

View File

@ -31,6 +31,8 @@ const (
UserIssuedAPI = "api" UserIssuedAPI = "api"
UserIssuedIntegration = "integration" UserIssuedIntegration = "integration"
errUserNotPartOfAccountMsg = "user is not part of this account"
) )
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown // StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown