mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 11:20:18 +02:00
Refactor GetAccountIDByUserOrAccountID
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -76,7 +76,7 @@ type AccountManager interface {
|
|||||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||||
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
@@ -1260,37 +1260,31 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
|
||||||
// If an accountID is provided, it checks if the account exists and returns it.
|
// If user does have an account, it returns the user's account ID.
|
||||||
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
|
||||||
// If the user doesn't have an account, it creates one using the provided domain.
|
// If the user doesn't have an account, it creates one using the provided domain.
|
||||||
// Returns the account ID or an error if none is found or created.
|
// Returns the account ID or an error if none is found or created.
|
||||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
|
||||||
if accountID != "" {
|
if userID == "" {
|
||||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
|
||||||
}
|
|
||||||
return accountID, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userID != "" {
|
accountID, err := am.Store.GetAccountIDByUserID(userID)
|
||||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
if err != nil {
|
||||||
if err != nil {
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||||
}
|
if err != nil {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
}
|
||||||
|
return account.Id, nil
|
||||||
}
|
}
|
||||||
|
return "", err
|
||||||
return account.Id, nil
|
|
||||||
}
|
}
|
||||||
|
return accountID, nil
|
||||||
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(i idp.Manager) bool {
|
func isNil(i idp.Manager) bool {
|
||||||
@@ -1794,6 +1788,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
|
|||||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1914,7 +1912,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
|||||||
// if Account ID is part of the claims
|
// if Account ID is part of the claims
|
||||||
// it means that we've already classified the domain and user has an account
|
// it means that we've already classified the domain and user has an account
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
if claims.AccountId != "" {
|
||||||
|
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
|
||||||
|
}
|
||||||
|
return claims.AccountId, nil
|
||||||
|
}
|
||||||
|
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||||
} else if claims.AccountId != "" {
|
} else if claims.AccountId != "" {
|
||||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -2227,7 +2235,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
|||||||
routes := make(map[route.ID]*route.Route)
|
routes := make(map[route.ID]*route.Route)
|
||||||
setupKeys := map[string]*SetupKey{}
|
setupKeys := map[string]*SetupKey{}
|
||||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||||
users[userID] = NewOwnerUser(userID)
|
|
||||||
|
owner := NewOwnerUser(userID)
|
||||||
|
owner.AccountID = accountID
|
||||||
|
users[userID] = owner
|
||||||
|
|
||||||
dnsSettings := DNSSettings{
|
dnsSettings := DNSSettings{
|
||||||
DisabledManagementGroups: make([]string, 0),
|
DisabledManagementGroups: make([]string, 0),
|
||||||
}
|
}
|
||||||
|
@@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
|
||||||
require.NoError(t, err, "create init user failed")
|
require.NoError(t, err, "create init user failed")
|
||||||
|
|
||||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
@@ -676,10 +676,10 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID := initAccount.Id
|
accountID := initAccount.Id
|
||||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
|
accountID, err = manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||||
require.NoError(t, err, "create init user failed")
|
require.NoError(t, err, "create init user failed")
|
||||||
// as initAccount was created without account id we have to take the id after account initialization
|
// as initAccount was created without account id we have to take the id after account initialization
|
||||||
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
|
// that happens inside the GetAccountIDByUserID where the id is getting generated
|
||||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||||
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
|
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
|
||||||
require.NoError(t, err, "get init account failed")
|
require.NoError(t, err, "get init account failed")
|
||||||
@@ -885,7 +885,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -894,7 +894,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
|||||||
|
|
||||||
userId := "test_user"
|
userId := "test_user"
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -903,14 +903,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
|
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||||
}
|
|
||||||
|
|
||||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
|
_, err = manager.GetAccountIDByUserID(context.Background(), "", "")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("expected an error when user and account IDs are empty")
|
t.Errorf("expected an error when user ID is empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1668,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
@@ -1683,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1695,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
@@ -1741,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1769,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
@@ -1789,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
key, err := wgtypes.GenerateKey()
|
key, err := wgtypes.GenerateKey()
|
||||||
@@ -1801,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "unable to add peer")
|
require.NoError(t, err, "unable to add peer")
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to get the account")
|
require.NoError(t, err, "unable to get the account")
|
||||||
|
|
||||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||||
@@ -1849,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||||
require.NoError(t, err, "unable to create an account")
|
require.NoError(t, err, "unable to create an account")
|
||||||
|
|
||||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||||
@@ -1860,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
|||||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||||
|
|
||||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
|
||||||
require.NoError(t, err, "unable to get account by ID")
|
|
||||||
|
|
||||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
require.NoError(t, err, "unable to get account settings")
|
require.NoError(t, err, "unable to get account settings")
|
||||||
|
|
||||||
|
@@ -201,7 +201,7 @@ func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context,
|
|||||||
}
|
}
|
||||||
return "", status.Errorf(
|
return "", status.Errorf(
|
||||||
codes.Unimplemented,
|
codes.Unimplemented,
|
||||||
"method GetAccountIDByUserOrAccountID is not implemented",
|
"method GetAccountIDByUserID is not implemented",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -798,10 +798,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
|
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
acc, err := am.Store.GetAccount(context.Background(), accID)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for _, id := range tc.expectedDeleted {
|
for _, id := range tc.expectedDeleted {
|
||||||
|
Reference in New Issue
Block a user