From 7d319d5a2ebca23bf6653425e96d5fd3158c4f09 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 2 Oct 2024 16:35:02 +0300 Subject: [PATCH] Refactor GetAccountIDByUserOrAccountID Signed-off-by: bcmmbaga --- management/server/account.go | 66 +++++++++++-------- management/server/account_test.go | 40 +++++------ management/server/mock_server/account_mock.go | 2 +- management/server/user_test.go | 5 +- 4 files changed, 59 insertions(+), 54 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 710b6f62f..aee386211 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,7 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) - GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -1260,37 +1260,31 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. -// If an accountID is provided, it checks if the account exists and returns it. -// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// GetAccountIDByUserID retrieves the account ID based on the userID provided. +// If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. -func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { - if accountID != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) - } - return accountID, nil +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { + if userID == "" { + return "", status.Errorf(status.NotFound, "no valid userID provided") } - if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) - if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) - } + accountID, err := am.Store.GetAccountIDByUserID(userID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { - return "", err + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + return account.Id, nil } - - return account.Id, nil + return "", err } - - return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") + return accountID, nil } func isNil(i idp.Manager) bool { @@ -1794,6 +1788,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } + if user.AccountID != accountID { + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + } + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { @@ -1914,7 +1912,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + if claims.AccountId != "" { + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) + } + return claims.AccountId, nil + } + return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } else if claims.AccountId != "" { userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { @@ -2227,7 +2235,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + + owner := NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } diff --git a/management/server/account_test.go b/management/server/account_test.go index 303261bea..be9d557a5 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) @@ -676,10 +676,10 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err = manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err = manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") @@ -885,7 +885,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { +func TestAccountManager_GetAccountByUserID(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -894,7 +894,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -903,14 +903,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) - } + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + assert.NoError(t, err) + assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), "", "") if err == nil { - t.Errorf("expected an error when user and account IDs are empty") + t.Errorf("expected an error when user ID is empty") } } @@ -1668,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) @@ -1683,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1695,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1741,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1769,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1789,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1801,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1849,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1860,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - require.NoError(t, err, "unable to get account by ID") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index df12ec1c4..abf4cc185 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -201,7 +201,7 @@ func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, } return "", status.Errorf( codes.Unimplemented, - "method GetAccountIDByUserOrAccountID is not implemented", + "method GetAccountIDByUserID is not implemented", ) } diff --git a/management/server/user_test.go b/management/server/user_test.go index e394ef840..8ca607d75 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -798,10 +798,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") - assert.NoError(t, err) - - acc, err := am.Store.GetAccount(context.Background(), accID) + acc, err := am.Store.GetAccount(context.Background(), account.Id) assert.NoError(t, err) for _, id := range tc.expectedDeleted {