diff --git a/management/server/account.go b/management/server/account.go index 33e7fc11c..fd1fcecec 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -53,7 +53,10 @@ const ( DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerInactivityExpiration = 10 * time.Minute emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + + errGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + errGetAccountFmt = "failed to get account: %w" + errNetworkSerialIncrementFmt = "failed to increment network serial: %w" ) type userLoggedInOnce bool @@ -111,7 +114,6 @@ type AccountManager interface { SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) @@ -1144,7 +1146,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco var oldSettings *Settings err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to get account settings: %w", err) } @@ -1153,7 +1155,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return fmt.Errorf("failed to validate extra settings: %w", err) } - if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { + if err = transaction.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { return fmt.Errorf("failed to update account settings: %w", err) } @@ -2049,7 +2051,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting user: %w", err) } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2079,7 +2081,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, accountID) + groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2226,7 +2228,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if handleNotFound(err) != nil { - log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err) return "", nil, err } @@ -2240,7 +2242,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont // check again if the domain has a primary account because of simultaneous requests domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if handleNotFound(err) != nil { - log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err) return "", nil, err } @@ -2271,7 +2273,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context // We checked if the domain has a primary account already domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) if handleNotFound(err) != nil { - log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + log.WithContext(ctx).Errorf(errGettingDomainAccIDFmt, err) return "", err } diff --git a/management/server/account_test.go b/management/server/account_test.go index aaef8fe8f..b65777981 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2612,7 +2612,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") diff --git a/management/server/dns.go b/management/server/dns.go index 12a332156..85cca221a 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -113,7 +113,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return err } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return err } diff --git a/management/server/group.go b/management/server/group.go index 74b7f977f..4bf807d90 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked { + return status.Errorf(status.PermissionDenied, "access to groups is blocked for users") + } + + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) } return nil @@ -59,7 +63,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us return nil, err } - return am.Store.GetAccountGroups(ctx, accountID) + return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers @@ -80,7 +84,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "no permission to create group") + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) } var ( @@ -126,7 +130,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) + return fmt.Errorf(errNetworkSerialIncrementFmt, err) } if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { @@ -144,7 +148,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return fmt.Errorf(errGetAccountFmt, err) } am.updateAccountPeers(ctx, account) @@ -229,7 +233,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "no permission to delete group") + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) } group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) @@ -247,7 +251,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) + return fmt.Errorf(errNetworkSerialIncrementFmt, err) } if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); err != nil { @@ -263,7 +267,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return fmt.Errorf(errGetAccountFmt, err) } am.updateAccountPeers(ctx, account) @@ -278,7 +282,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us } if user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "no permission to delete groups") + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) } var ( @@ -304,7 +308,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) + return fmt.Errorf(errNetworkSerialIncrementFmt, err) } if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { @@ -322,46 +326,20 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return fmt.Errorf(errGetAccountFmt, err) } am.updateAccountPeers(ctx, account) return allErrors } -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil -} - // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - add := true for _, itemID := range group.Peers { if itemID == peerID { @@ -373,11 +351,24 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr group.Peers = append(group.Peers, peerID) } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf(errNetworkSerialIncrementFmt, err) + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) + } + return nil + }) + if err != nil { return err } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } am.updateAccountPeers(ctx, account) return nil @@ -385,29 +376,42 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - account.Network.IncSerial() + updated := false for i, itemID := range group.Peers { if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(ctx, account); err != nil { - return err - } + updated = true + break } } + if !updated { + return nil + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf(errNetworkSerialIncrementFmt, err) + } + + if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { + return fmt.Errorf("failed to save group: %w", err) + } + return nil + }) + if err != nil { + return err + } + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } am.updateAccountPeers(ctx, account) return nil diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c..ba6a20259 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -56,13 +56,15 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId if len(groups) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) + + accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountId) if err != nil { return false, err } + for _, group := range groups { var found bool - for _, accountGroup := range accountsGroups { + for _, accountGroup := range accountGroups { if accountGroup.ID == group { found = true break diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 51eef2af3..a996db298 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error @@ -347,14 +346,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 90c1eefa2..5d2f9d90f 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -216,7 +216,7 @@ func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, ac return err } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return err } diff --git a/management/server/peer.go b/management/server/peer.go index 461b9e310..bed3474e7 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -743,7 +743,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } diff --git a/management/server/route_test.go b/management/server/route_test.go index 09cbe53ff..17eac951d 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1090,7 +1090,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(context.Background(), account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index d6bdf74f0..838a70ff6 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -225,7 +225,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s keyDuration = expiresIn } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -278,7 +278,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.PermissionDenied, "only users with admin power can update setup keys") } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index dde15b265..a4e688259 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -555,9 +555,9 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Find(&groups, accountIDCondition, accountID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -841,7 +841,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveAccountSettings stores the account settings in DB. func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error { - result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings}) if result.Error != nil { return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 8feec2b77..69d91c980 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1201,7 +1201,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } diff --git a/management/server/store.go b/management/server/store.go index deb55ea3f..6ab83c458 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -72,7 +72,7 @@ type Store interface { DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error diff --git a/management/server/user.go b/management/server/user.go index 17c5065ef..cd5e7c334 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -31,6 +31,8 @@ const ( UserIssuedAPI = "api" UserIssuedIntegration = "integration" + + errUserNotPartOfAccountMsg = "user is not part of this account" ) // StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown