mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-31 10:31:58 +01:00
refactor getAccountWithAuthorizationClaims
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
ccab3b427f
commit
720d36a290
@ -1625,21 +1625,13 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error {
|
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
||||||
// only possible with the enabled IdP manager
|
// only possible with the enabled IdP manager
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
unlock()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -1759,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
account, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
@ -1771,17 +1763,17 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
err = am.redeemInvite(ctx, account, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.syncJWTGroups(ctx, claims, accountID); err != nil {
|
if err = am.syncJWTGroups(ctx, claims, account.Id); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return accountID, user.Id, nil
|
return account.Id, user.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
@ -1884,31 +1876,27 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtcl
|
|||||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||||
//
|
//
|
||||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return "", fmt.Errorf("user ID is empty")
|
return nil, fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if Account ID is part of the claims
|
// if Account ID is part of the claims
|
||||||
// it means that we've already classified the domain and user has an account
|
// it means that we've already classified the domain and user has an account
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||||
if err != nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return account.Id, nil
|
|
||||||
} else if claims.AccountId != "" {
|
} else if claims.AccountId != "" {
|
||||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
||||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||||
}
|
}
|
||||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
||||||
return accountFromID.Id, nil
|
return accountFromID, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1918,53 +1906,47 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
|||||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if NotFound we are good to continue, otherwise return error
|
// if NotFound we are good to continue, otherwise return error
|
||||||
e, ok := status.FromError(err)
|
e, ok := status.FromError(err)
|
||||||
if !ok || e.Type() != status.NotFound {
|
if !ok || e.Type() != status.NotFound {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||||
// and peers that shouldn't be lost.
|
// and peers that shouldn't be lost.
|
||||||
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
|
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||||
|
|
||||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
return account.Id, nil
|
return account, nil
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
if domainAccount != nil {
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
defer unlockAccount()
|
||||||
defer unlockAccount()
|
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err = am.handleNewUserAccount(ctx, domainAccount, claims)
|
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return account.Id, nil
|
|
||||||
} else {
|
} else {
|
||||||
// other error
|
// other error
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -803,3 +803,7 @@ func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
|||||||
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
|
||||||
|
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||||
|
}
|
||||||
|
@ -397,20 +397,33 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
|
||||||
|
accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: rework to not call GetAccount
|
||||||
|
return s.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
|
||||||
var account Account
|
var account Account
|
||||||
|
|
||||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||||
strings.ToLower(domain), true, PrivateCategory)
|
strings.ToLower(domain), true, PrivateCategory)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
|
||||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: rework to not call GetAccount
|
return account.Id, nil
|
||||||
return s.GetAccount(ctx, account.Id)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
|
||||||
|
@ -34,11 +34,12 @@ type Store interface {
|
|||||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountIDByUserID(peerKey string) (string, error)
|
GetAccountIDByUserID(userID string) (string, error)
|
||||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||||
|
GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error)
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
||||||
|
Loading…
Reference in New Issue
Block a user