Remove get account from groups ops

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-16 16:04:34 +03:00
parent 1123729c1c
commit d7c63d5c04
14 changed files with 87 additions and 86 deletions

View File

@ -53,7 +53,10 @@ const (
DefaultPeerLoginExpiration = 24 * time.Hour
DefaultPeerInactivityExpiration = 10 * time.Minute
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
errGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
errGetAccountFmt = "failed to get account: %w"
errNetworkSerialIncrementFmt = "failed to increment network serial: %w"
)
type userLoggedInOnce bool
@ -111,7 +114,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@ -1144,7 +1146,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
var oldSettings *Settings
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
@ -1153,7 +1155,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return fmt.Errorf("failed to validate extra settings: %w", err)
}
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
if err = transaction.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
return fmt.Errorf("failed to update account settings: %w", err)
}
@ -2049,7 +2051,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err)
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@ -2079,7 +2081,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, accountID)
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@ -2226,7 +2228,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err)
return "", nil, err
}
@ -2240,7 +2242,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
// check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err)
return "", nil, err
}
@ -2271,7 +2273,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err)
return "", err
}

View File

@ -2612,7 +2612,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")

View File

@ -113,7 +113,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}

View File

@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err
}
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked {
return status.Errorf(status.PermissionDenied, "access to groups is blocked for users")
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
return nil
@ -59,7 +63,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
return nil, err
}
return am.Store.GetAccountGroups(ctx, accountID)
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
@ -80,7 +84,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to create group")
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
var (
@ -126,7 +130,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
@ -144,7 +148,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
@ -229,7 +233,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to delete group")
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
@ -247,7 +251,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); err != nil {
@ -263,7 +267,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
@ -278,7 +282,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to delete groups")
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
var (
@ -304,7 +308,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
@ -322,46 +326,20 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
@ -373,11 +351,24 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
return nil
@ -385,29 +376,42 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
updated := false
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
updated = true
break
}
}
if !updated {
return nil
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
return nil

View File

@ -56,13 +56,15 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
if len(groups) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(ctx, accountId)
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountId)
if err != nil {
return false, err
}
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
for _, accountGroup := range accountGroups {
if accountGroup.ID == group {
found = true
break

View File

@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@ -347,14 +346,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil {

View File

@ -216,7 +216,7 @@ func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, ac
return err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}

View File

@ -743,7 +743,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}

View File

@ -1090,7 +1090,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups {

View File

@ -225,7 +225,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
keyDuration = expiresIn
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
@ -278,7 +278,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can update setup keys")
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}

View File

@ -555,9 +555,9 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
return users, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -841,7 +841,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
// SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)

View File

@ -1201,7 +1201,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}

View File

@ -72,7 +72,7 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error

View File

@ -31,6 +31,8 @@ const (
UserIssuedAPI = "api"
UserIssuedIntegration = "integration"
errUserNotPartOfAccountMsg = "user is not part of this account"
)
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown