mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-22 05:01:24 +01:00
refactor getAccountFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
e5d55d3c10
commit
ccab3b427f
@ -1625,13 +1625,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1751,94 +1759,112 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
}
|
||||
|
||||
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, newAcc.Id)
|
||||
accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
user := account.Users[claims.UserId]
|
||||
if user == nil {
|
||||
user, err := am.Store.GetUserByUserID(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
err = am.redeemInvite(ctx, account, claims.UserId)
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
if account.Settings.JWTGroupsEnabled {
|
||||
if account.Settings.JWTGroupsClaimName == "" {
|
||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||
if err = am.syncJWTGroups(ctx, claims, accountID); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return account.Id, user.Id, nil
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !settings.JWTGroupsEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if settings.JWTGroupsClaimName == "" {
|
||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||
return nil
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(claims.UserId)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldGroups := make([]string, len(user.AutoGroups))
|
||||
copy(oldGroups, user.AutoGroups)
|
||||
|
||||
// Update the account if group membership changes
|
||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||
account.Network.IncSerial()
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, account.Settings.JWTGroupsClaimName, claims)
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
oldGroups := make([]string, len(user.AutoGroups))
|
||||
copy(oldGroups, user.AutoGroups)
|
||||
// if groups were added or modified, save the account
|
||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
if user, err := account.FindUser(claims.UserId); err == nil {
|
||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||
account.Network.IncSerial()
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
unlock()
|
||||
alreadyUnlocked = true
|
||||
for _, g := range addNewGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
for _, g := range removeOldGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
}
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return account.Id, user.Id, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
@ -1858,27 +1884,31 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
return account.Id, nil
|
||||
} else if claims.AccountId != "" {
|
||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
||||
return accountFromID, nil
|
||||
return accountFromID.Id, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -1893,7 +1923,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
@ -1903,7 +1933,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
@ -1914,22 +1944,27 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
return account, nil
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
|
||||
account, err = am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -645,8 +645,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||
accountID, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "get account by account id")
|
||||
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
|
||||
@ -685,8 +689,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "get account by account id")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
@ -696,8 +704,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "get account by account id")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
@ -708,8 +720,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "get account by account id")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*group.Group{}
|
||||
|
Loading…
Reference in New Issue
Block a user