From 720d36a2901e7eb3143e7bd172b66514b67cfc37 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 18 Sep 2024 15:55:52 +0300 Subject: [PATCH] refactor getAccountWithAuthorizationClaims Signed-off-by: bcmmbaga --- management/server/account.go | 72 +++++++++++++-------------------- management/server/file_store.go | 4 ++ management/server/sql_store.go | 21 ++++++++-- management/server/store.go | 3 +- 4 files changed, 50 insertions(+), 50 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0108c2758..84bdc629f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1625,21 +1625,13 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - - account, err := am.Store.GetAccount(ctx, accountID) - unlock() - if err != nil { - return err - } - user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err @@ -1759,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims) + account, err := am.getAccountWithAuthorizationClaims(ctx, claims) if err != nil { return "", "", err } @@ -1771,17 +1763,17 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(ctx, accountID, user.Id) + err = am.redeemInvite(ctx, account, user.Id) if err != nil { return "", "", err } } - if err = am.syncJWTGroups(ctx, claims, accountID); err != nil { + if err = am.syncJWTGroups(ctx, claims, account.Id); err != nil { return "", "", err } - return accountID, user.Id, nil + return account.Id, user.Id, nil } // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, @@ -1884,31 +1876,27 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtcl // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { +func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return nil, fmt.Errorf("user ID is empty") } // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) - if err != nil { - return "", nil - } - return account.Id, nil + return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) if err != nil { - return "", err + return nil, err } if _, ok := accountFromID.Users[claims.UserId]; !ok { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { - return accountFromID.Id, nil + return accountFromID, nil } } @@ -1918,53 +1906,47 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { - return "", err + return nil, err } } - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) defer unlockAccount() - account, err = am.Store.GetAccountByUser(ctx, claims.UserId) + account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { - return "", err + return nil, err } // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. - primaryDomain := domainAccount == nil || account.Id == domainAccount.Id + primaryDomain := domainAccountID == "" || account.Id == domainAccountID err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) if err != nil { - return "", err + return nil, err } - return account.Id, nil + return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - if domainAccount != nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) + if err != nil { + return nil, err } - account, err = am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil + return am.handleNewUserAccount(ctx, domainAccount, claims) } else { // other error - return "", err + return nil, err } } diff --git a/management/server/file_store.go b/management/server/file_store.go index 1927568ef..ed1bc3d09 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -803,3 +803,7 @@ func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } + +func (s *FileStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) { + return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 0fb3d391f..49a5ddeb4 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -397,20 +397,33 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { } func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { + accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain) + if err != nil { + return nil, err + } + + if accountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + // TODO: rework to not call GetAccount + return s.GetAccount(ctx, accountID) +} + +func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) { var account Account result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", strings.ToLower(domain), true, PrivateCategory) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return "", status.Errorf(status.Internal, "issue getting account from store") } - // TODO: rework to not call GetAccount - return s.GetAccount(ctx, account.Id) + return account.Id, nil } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { diff --git a/management/server/store.go b/management/server/store.go index a2b489391..dcff80ee5 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -34,11 +34,12 @@ type Store interface { GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(peerKey string) (string, error) + GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, userID string) (*User, error)