refactor getAccountWithAuthorizationClaims

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-18 15:55:52 +03:00
parent ccab3b427f
commit 720d36a290
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 50 additions and 50 deletions

View File

@ -1625,21 +1625,13 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
} }
// redeemInvite checks whether user has been invited and redeems the invite // redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
// only possible with the enabled IdP manager // only possible with the enabled IdP manager
if am.idpManager == nil { if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil return nil
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
account, err := am.Store.GetAccount(ctx, accountID)
unlock()
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account) user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil { if err != nil {
return err return err
@ -1759,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
} }
accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims) account, err := am.getAccountWithAuthorizationClaims(ctx, claims)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -1771,17 +1763,17 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
} }
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, account, user.Id)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
} }
if err = am.syncJWTGroups(ctx, claims, accountID); err != nil { if err = am.syncJWTGroups(ctx, claims, account.Id); err != nil {
return "", "", err return "", "", err
} }
return accountID, user.Id, nil return account.Id, user.Id, nil
} }
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
@ -1884,31 +1876,27 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtcl
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" { if claims.UserId == "" {
return "", fmt.Errorf("user ID is empty") return nil, fmt.Errorf("user ID is empty")
} }
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
if err != nil {
return "", nil
}
return account.Id, nil
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
if err != nil { if err != nil {
return "", err return nil, err
} }
if _, ok := accountFromID.Users[claims.UserId]; !ok { if _, ok := accountFromID.Users[claims.UserId]; !ok {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
} }
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID.Id, nil return accountFromID, nil
} }
} }
@ -1918,53 +1906,47 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain)
if err != nil { if err != nil {
// if NotFound we are good to continue, otherwise return error // if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err) e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound { if !ok || e.Type() != status.NotFound {
return "", err return nil, err
} }
} }
account, err := am.Store.GetAccountByUser(ctx, claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil { if err == nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
defer unlockAccount() defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId) account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil { if err != nil {
return "", err return nil, err
} }
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as // we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account // was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost. // and peers that shouldn't be lost.
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id primaryDomain := domainAccountID == "" || account.Id == domainAccountID
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
if err != nil { if err != nil {
return "", err return nil, err
} }
return account.Id, nil return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil { unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) defer unlockAccount()
defer unlockAccount() domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil {
if err != nil { return nil, err
return "", err
}
} }
account, err = am.handleNewUserAccount(ctx, domainAccount, claims) return am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else { } else {
// other error // other error
return "", err return nil, err
} }
} }

View File

@ -803,3 +803,7 @@ func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented") return status.Errorf(status.Internal, "SaveGroups is not implemented")
} }
func (s *FileStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}

View File

@ -397,20 +397,33 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain)
if err != nil {
return nil, err
}
if accountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
// TODO: rework to not call GetAccount
return s.GetAccount(ctx, accountID)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
var account Account var account Account
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory) strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
} }
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store") return "", status.Errorf(status.Internal, "issue getting account from store")
} }
// TODO: rework to not call GetAccount return account.Id, nil
return s.GetAccount(ctx, account.Id)
} }
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {

View File

@ -34,11 +34,12 @@ type Store interface {
GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(peerKey string) (string, error) GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error) GetUserByUserID(ctx context.Context, userID string) (*User, error)