refactor getAccountFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-18 14:24:39 +03:00
parent e5d55d3c10
commit ccab3b427f
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
2 changed files with 135 additions and 84 deletions

View File

@ -1625,13 +1625,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
}
// redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
account, err := am.Store.GetAccount(ctx, accountID)
unlock()
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil {
return err
@ -1751,94 +1759,112 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
if err != nil {
return "", "", err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
unlock()
}
}()
account, err := am.Store.GetAccount(ctx, newAcc.Id)
accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims)
if err != nil {
return "", "", err
}
user := account.Users[claims.UserId]
if user == nil {
user, err := am.Store.GetUserByUserID(ctx, claims.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, account, claims.UserId)
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
return "", "", err
}
}
if account.Settings.JWTGroupsEnabled {
if account.Settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
if err = am.syncJWTGroups(ctx, claims, accountID); err != nil {
return "", "", err
}
return account.Id, user.Id, nil
return accountID, user.Id, nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error {
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return err
}
if !settings.JWTGroupsEnabled {
return nil
}
if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
return nil
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(claims.UserId)
if err != nil {
return nil
}
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
if settings.GroupsPropagationEnabled {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
}
jwtGroupsNames := extractJWTGroups(ctx, account.Settings.JWTGroupsClaimName, claims)
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
return nil
}
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// if groups were added or modified, save the account
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
if account.Settings.GroupsPropagationEnabled {
if user, err := account.FindUser(claims.UserId); err == nil {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
} else {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
unlock()
alreadyUnlocked = true
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
}
} else {
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
}
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}
for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
return account.Id, user.Id, nil
return nil
}
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
@ -1858,27 +1884,31 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
return "", fmt.Errorf("user ID is empty")
}
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
if err != nil {
return "", nil
}
return account.Id, nil
} else if claims.AccountId != "" {
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
if err != nil {
return nil, err
return "", err
}
if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID, nil
return accountFromID.Id, nil
}
}
@ -1893,7 +1923,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
return "", err
}
}
@ -1903,7 +1933,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
return nil, err
return "", err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
@ -1914,22 +1944,27 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
if err != nil {
return nil, err
return "", err
}
return account, nil
return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
return nil, err
return "", err
}
}
return am.handleNewUserAccount(ctx, domainAccount, claims)
account, err = am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else {
// other error
return nil, err
return "", err
}
}

View File

@ -645,8 +645,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
accountID, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@ -685,8 +689,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
@ -696,8 +704,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@ -708,8 +720,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{}