mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 18:01:13 +01:00
refactor getAccountWithAuthorizationClaims to return account id
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
9631cb4fb3
commit
4d9bb7ea35
@ -75,7 +75,7 @@ type AccountManager interface {
|
||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
||||
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
@ -1252,25 +1252,30 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
||||
// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
||||
// userID doesn't have an account associated with it, one account is created
|
||||
// domain is used to create a new account if no account is found
|
||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||
if accountID != "" {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
_, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accountID, nil
|
||||
} else if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||
return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
return account, nil
|
||||
return account.Id, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||
return "", status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||
}
|
||||
|
||||
func isNil(i idp.Manager) bool {
|
||||
@ -1613,13 +1618,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1739,7 +1752,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
}
|
||||
|
||||
account, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -1751,26 +1764,28 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
err = am.redeemInvite(ctx, account, user.Id)
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
if err = am.syncJWTGroups(ctx, account, user, claims); err != nil {
|
||||
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
// TODO: return account id, user id and error
|
||||
return &Account{Id: accountID}, user, nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Account, user *User, claims jwtclaims.AuthorizationClaims) error {
|
||||
settings := account.Settings
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings == nil || !settings.JWTGroupsEnabled {
|
||||
return nil
|
||||
}
|
||||
@ -1780,6 +1795,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Acc
|
||||
return nil
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
oldGroups := make([]string, len(user.AutoGroups))
|
||||
@ -1833,7 +1856,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Acc
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
@ -1850,27 +1873,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Acc
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
} else if claims.AccountId != "" {
|
||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
||||
return accountFromID, nil
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||
return userAccountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -1885,7 +1915,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
@ -1895,7 +1925,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
defer unlockAccount()
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
@ -1903,12 +1933,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account, nil
|
||||
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
var domainAccount *Account
|
||||
if domainAccountID != "" {
|
||||
@ -1916,14 +1945,18 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -931,7 +931,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
||||
}
|
||||
|
||||
@ -950,14 +950,18 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
|
||||
return FileStoreEngine
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
|
||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||
func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
|
||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
|
||||
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) {
|
||||
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ string) (string, string, error) {
|
||||
return "", "", status.Errorf(status.Internal, "GetAccountDomainAndCategory is not implemented")
|
||||
}
|
||||
|
@ -1033,3 +1033,18 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
|
||||
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) {
|
||||
var account Account
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category").
|
||||
Where(idQueryCondition, accountID).First(&account)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", "", status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
return "", "", status.Errorf(status.Internal, "failed to retrieve account fields")
|
||||
}
|
||||
|
||||
return account.Domain, account.DomainCategory, nil
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ const (
|
||||
type Store interface {
|
||||
GetAllAccounts(ctx context.Context) []*Account
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error)
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
|
@ -360,16 +360,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
||||
account, user, err := am.GetAccountFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
|
||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||
|
Loading…
Reference in New Issue
Block a user