mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
[management] Refactor getAccountIDWithAuthorizationClaims (#2715)
This change restructures the getAccountIDWithAuthorizationClaims method to improve readability, maintainability, and performance. - have dedicated methods to handle possible cases - introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods - Remove GetAccount and SaveAccount dependency - added tests
This commit is contained in:
parent
0e95f16cdd
commit
da3a053e2b
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
@ -50,6 +51,8 @@ const (
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@ -1285,7 +1288,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
@ -1300,28 +1303,39 @@ func isNil(i idp.Manager) bool {
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cachedAccount := &Account{
|
||||
Id: accountID,
|
||||
Users: make(map[string]*User),
|
||||
}
|
||||
for _, user := range accountUsers {
|
||||
cachedAccount.Users[user.Id] = user
|
||||
}
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
account.Id, userID)
|
||||
accountID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1545,48 +1559,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
|
||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
|
||||
if claims.Domain != "" {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
userObj := account.Users[claims.UserId]
|
||||
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
|
||||
account.Domain = lowerDomain
|
||||
}
|
||||
// prevent updating category for different domain until admin logs in
|
||||
if account.Domain == lowerDomain {
|
||||
account.DomainCategory = claims.DomainCategory
|
||||
}
|
||||
} else {
|
||||
if claims.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(ctx, account)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
newDomain := accountDomain
|
||||
newCategoty := domainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||
newDomain = lowerDomain
|
||||
}
|
||||
|
||||
if accountDomain == lowerDomain {
|
||||
newCategoty = claims.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
ctx context.Context,
|
||||
existingAcc *Account,
|
||||
primaryDomain bool,
|
||||
userAccountID string,
|
||||
domainAccountID string,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) error {
|
||||
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1594,44 +1629,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
)
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
// if domain already has a primary account, add regular user
|
||||
if domainAcc != nil {
|
||||
account = domainAcc
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.updateAccountDomainAttributes(ctx, account, claims, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
|
||||
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
|
||||
newAccount.Domain = lowerDomain
|
||||
newAccount.DomainCategory = claims.DomainCategory
|
||||
newAccount.IsDomainPrimaryAccount = true
|
||||
|
||||
return account, nil
|
||||
err = am.Store.SaveAccount(ctx, newAccount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
usersMap := make(map[string]*User)
|
||||
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
@ -1775,7 +1824,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
return "", "", fmt.Errorf("user ID is empty")
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
@ -1961,16 +2010,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||
// if domain is not private or domain is invalid, it will return the account ID by user ID.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
// Use cases:
|
||||
//
|
||||
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
|
||||
// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
|
||||
//
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
|
||||
// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
|
||||
//
|
||||
// New user + New account + Existing Public Domain -> create account, user role = admin
|
||||
// New user + New account + Existing Public Domain -> create account, user role = owner
|
||||
//
|
||||
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
|
||||
//
|
||||
@ -1980,98 +2030,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
|
||||
if claims.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
return "", errors.New(emptyUserID)
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
if claims.AccountId != "" {
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !exists {
|
||||
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
|
||||
}
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||
} else if claims.AccountId != "" {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||
return userAccountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
||||
defer unlock()
|
||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
if claims.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return "", err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||
defer unlockAccount()
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
var domainAccount *Account
|
||||
if domainAccountID != "" {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != "" {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userAccountID, nil
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||
}
|
||||
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return domainAccountID, nil, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
|
||||
cancel := am.Store.AcquireGlobalLock(ctx)
|
||||
|
||||
// check again if the domain has a primary account because of simultaneous requests
|
||||
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return domainAccountID, cancel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return claims.AccountId, nil
|
||||
}
|
||||
|
||||
func handleNotFound(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||
return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
|
@ -465,7 +465,26 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
privateDomain = "private.com"
|
||||
unknownDomain = "unknown.com"
|
||||
)
|
||||
|
||||
defaultInitAccount := initUserParams{
|
||||
Domain: publicDomain,
|
||||
UserId: "defaultUser",
|
||||
}
|
||||
|
||||
initUnknown := defaultInitAccount
|
||||
initUnknown.DomainCategory = UnknownCategory
|
||||
initUnknown.Domain = unknownDomain
|
||||
|
||||
privateInitAccount := defaultInitAccount
|
||||
privateInitAccount.Domain = privateDomain
|
||||
privateInitAccount.DomainCategory = PrivateCategory
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputClaims jwtclaims.AuthorizationClaims
|
||||
inputInitUserParams initUserParams
|
||||
@ -479,156 +498,131 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
expectedPrimaryDomainStatus bool
|
||||
expectedCreatedBy string
|
||||
expectedUsers []string
|
||||
}
|
||||
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
privateDomain = "private.com"
|
||||
unknownDomain = "unknown.com"
|
||||
)
|
||||
|
||||
defaultInitAccount := initUserParams{
|
||||
Domain: publicDomain,
|
||||
UserId: "defaultUser",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
}{
|
||||
{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomainCategory: "",
|
||||
expectedDomain: publicDomain,
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pub-domain-user",
|
||||
expectedUsers: []string{"pub-domain-user"},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomainCategory: "",
|
||||
expectedDomain: publicDomain,
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pub-domain-user",
|
||||
expectedUsers: []string{"pub-domain-user"},
|
||||
}
|
||||
|
||||
initUnknown := defaultInitAccount
|
||||
initUnknown.DomainCategory = UnknownCategory
|
||||
initUnknown.Domain = unknownDomain
|
||||
|
||||
testCase2 := test{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: UnknownCategory,
|
||||
{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: UnknownCategory,
|
||||
},
|
||||
inputInitUserParams: initUnknown,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: unknownDomain,
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "unknown-domain-user",
|
||||
expectedUsers: []string{"unknown-domain-user"},
|
||||
},
|
||||
inputInitUserParams: initUnknown,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: unknownDomain,
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "unknown-domain-user",
|
||||
expectedUsers: []string{"unknown-domain-user"},
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
}
|
||||
|
||||
privateInitAccount := defaultInitAccount
|
||||
privateInitAccount.Domain = privateDomain
|
||||
privateInitAccount.DomainCategory = PrivateCategory
|
||||
|
||||
testCase4 := test{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputUpdateAttrs: true,
|
||||
inputInitUserParams: privateInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleUser,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||
},
|
||||
inputUpdateAttrs: true,
|
||||
inputInitUserParams: privateInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleUser,
|
||||
expectedDomain: privateDomain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
}
|
||||
|
||||
testCase6 := test{
|
||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputUpdateClaimAccount: true,
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
},
|
||||
inputUpdateClaimAccount: true,
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.Equal,
|
||||
expectedMSG: "account IDs should match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: defaultInitAccount.Domain,
|
||||
expectedDomainCategory: PrivateCategory,
|
||||
expectedPrimaryDomainStatus: true,
|
||||
expectedCreatedBy: defaultInitAccount.UserId,
|
||||
expectedUsers: []string{defaultInitAccount.UserId},
|
||||
}
|
||||
|
||||
testCase7 := test{
|
||||
name: "User With Private Category And Empty Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
{
|
||||
name: "User With Private Category And Empty Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: "",
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
},
|
||||
inputInitUserParams: defaultInitAccount,
|
||||
testingFunc: require.NotEqual,
|
||||
expectedMSG: "account IDs shouldn't match",
|
||||
expectedUserRole: UserRoleOwner,
|
||||
expectedDomain: "",
|
||||
expectedDomainCategory: "",
|
||||
expectedPrimaryDomainStatus: false,
|
||||
expectedCreatedBy: "pvt-domain-user",
|
||||
expectedUsers: []string{"pvt-domain-user"},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
|
@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
|
||||
accountCopy := Account{
|
||||
Domain: domain,
|
||||
DomainCategory: category,
|
||||
IsDomainPrimaryAccount: isPrimaryDomain,
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account %s", accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
|
@ -1191,3 +1191,63 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, len(account.Users))
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
t.Run("Should update attributes with public domain", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "example.com"
|
||||
category := "public"
|
||||
IsDomainPrimaryAccount := false
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain, account.Domain)
|
||||
require.Equal(t, category, account.DomainCategory)
|
||||
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||
})
|
||||
|
||||
t.Run("Should update attributes with private domain", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, domain, account.Domain)
|
||||
require.Equal(t, category, account.DomainCategory)
|
||||
require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||
})
|
||||
|
||||
t.Run("Should fail when account does not exist", func(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -58,9 +58,11 @@ type Store interface {
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
|
Loading…
Reference in New Issue
Block a user