diff --git a/management/server/account.go b/management/server/account.go index 178aabd55..fa71c28cd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1808,7 +1808,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai } } - if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } @@ -1817,7 +1817,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err @@ -1833,6 +1833,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) + hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames) if err != nil { return err @@ -1840,26 +1841,27 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // skip update if no changes if !hasChanges { - log.WithContext(ctx).Debugf("no changes in JWT group membership") return nil } return am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) + user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return err + } - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) + addNewGroups := difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups := difference(user.AutoGroups, updatedAutoGroups) + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } user.AutoGroups = updatedAutoGroups if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { - return fmt.Errorf("error saving groups: %w", err) - } - // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil {