From 9cb7336ef56aa3079b9a30f491aa7fd03f9e1fd6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 6 Nov 2024 15:59:12 +0300 Subject: [PATCH] fix tests Signed-off-by: bcmmbaga --- client/testdata/store.sql | 3 + management/server/account.go | 11 - management/server/account_test.go | 307 +++++------ management/server/dns_test.go | 119 ++--- management/server/ephemeral_test.go | 104 ++-- management/server/group_test.go | 72 ++- management/server/http/groups_handler_test.go | 2 +- .../http/middleware/auth_middleware_test.go | 11 +- management/server/http/peers_handler_test.go | 146 ++--- management/server/nameserver_test.go | 120 ++--- management/server/peer_test.go | 129 +++-- management/server/posture_checks.go | 2 +- management/server/posture_checks_test.go | 198 ++++--- management/server/route_test.go | 344 +++++------- management/server/setupkey_test.go | 52 +- management/server/sql_store_test.go | 256 ++++++--- management/server/user.go | 2 +- management/server/user_test.go | 498 +++++++++--------- 18 files changed, 1244 insertions(+), 1132 deletions(-) diff --git a/client/testdata/store.sql b/client/testdata/store.sql index ed5395486..02ff3dc4d 100644 --- a/client/testdata/store.sql +++ b/client/testdata/store.sql @@ -28,9 +28,12 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); INSERT INTO installations VALUES(1,''); COMMIT; diff --git a/management/server/account.go b/management/server/account.go index 0005f1f83..248222ea4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -817,15 +817,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap { return groupList } -func (a *Account) getTakenIPs() []net.IP { - var takenIps []net.IP - for _, existingPeer := range a.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps -} - func (a *Account) getPeerDNSLabels() lookupMap { existingLabels := make(lookupMap) for _, peer := range a.Peers { @@ -1147,8 +1138,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) - am.updateAccountPeers(ctx, accountID) - return newSettings, nil } diff --git a/management/server/account_test.go b/management/server/account_test.go index d0238fe0a..0e2198e03 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -398,7 +398,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } for _, testCase := range tt { - account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") + store := newStore(t) + + err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io") + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), "account-1") + require.NoError(t, err, "failed to get account") + account.UpdateSettings(&testCase.accountSettings) account.Network = network account.Peers = testCase.peers @@ -416,6 +423,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) + + store.Close(context.Background()) } } @@ -423,7 +432,15 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" accountID := "account_id" - account := newAccountWithId(context.Background(), accountID, userId, domain) + + store := newStore(t) + defer store.Close(context.Background()) + + err := newAccountWithId(context.Background(), store, accountID, userId, domain) + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } @@ -434,16 +451,16 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userID) return } - account, err = manager.Store.GetAccountByUser(context.Background(), userID) + account, err := manager.Store.GetAccountByUser(context.Background(), userID) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) return @@ -666,15 +683,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") - // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserID where the id is getting generated - // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") @@ -690,44 +704,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account groups") - require.Len(t, account.Groups, 1, "only ALL group should exists") + require.Len(t, accountGroups, 1, "only ALL group should exists") }) t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true - err := manager.Store.SaveAccount(context.Background(), initAccount) - require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings) + require.NoError(t, err, "failed to update account settings") + + accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare) + require.NoError(t, err, "failed to get account ids") + require.Len(t, accountIDs, 1, "only one account should exist") accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account groups") - require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") + require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT") }) t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" - err := manager.Store.SaveAccount(context.Background(), initAccount) - require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings) + require.NoError(t, err, "failed to update account settings") + + accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare) + require.NoError(t, err, "failed to get account ids") + require.Len(t, accountIDs, 1, "only one account should exist") accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to check account existence") + require.True(t, exists, "account should exist") - require.Len(t, account.Groups, 3, "groups should be added to the account") + accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId) + require.NoError(t, err, "failed to get account groups") + require.Len(t, accountGroups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} - for _, g := range account.Groups { + for _, g := range accountGroups { groupsByNames[g.Name] = g } @@ -745,60 +768,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + err := newAccountWithId(context.Background(), store, "account_id", "testuser", "") + require.NoError(t, err, "failed to create account") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ - Id: "someUser", - PATs: map[string]*PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - HashedToken: encodedHashedToken, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) + + userPAT := &PersonalAccessToken{ + ID: "tokenId", + UserID: "testuser", + HashedToken: encodedHashedToken, + CreatedAt: time.Now().UTC(), } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT) + require.NoError(t, err, "failed to save PAT") am := DefaultAccountManager{ Store: store, } - account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) + user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } - assert.Equal(t, "account_id", account.Id) - assert.Equal(t, "someUser", user.Id) - assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) + assert.Equal(t, "account_id", user.AccountID) + assert.Equal(t, "testuser", user.Id) + assert.Equal(t, userPAT, pat) } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + err := newAccountWithId(context.Background(), store, "account_id", "testuser", "") + require.NoError(t, err, "failed to create account") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ - Id: "someUser", - PATs: map[string]*PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - HashedToken: encodedHashedToken, - LastUsed: time.Time{}, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) + + userPAT := &PersonalAccessToken{ + ID: "tokenId", + UserID: "someUser", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT) + require.NoError(t, err, "failed to save PAT") am := DefaultAccountManager{ Store: store, @@ -809,11 +825,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { t.Fatalf("Error when marking PAT used: %s", err) } - account, err = am.Store.GetAccount(context.Background(), "account_id") - if err != nil { - t.Fatalf("Error when getting account: %s", err) - } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) + userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID) + require.NoError(t, err, "failed to get PAT") + + assert.True(t, !userPAT.LastUsed.IsZero()) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -824,15 +839,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.Store.GetAccountByUser(context.Background(), userId) + account, err := manager.Store.GetAccountByUser(context.Background(), userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -851,32 +866,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { userId := "test_user" domain := "hotmail.com" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) - if err != nil { - t.Fatal(err) - } - if account == nil { - t.Fatalf("expected to create an account for a user %s", userId) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain) + require.NoError(t, err, "failed to get or create account by user") + require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId) - if account != nil && account.Domain != domain { - t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) - } + accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account domain and category") + require.Equal(t, domain, accDomain, "expected account domain to match") domain = "gmail.com" - account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) - if err != nil { - t.Fatalf("got the following error while retrieving existing acc: %v", err) - } + accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain) + require.NoError(t, err, "failed to get or create account by user") - if account == nil { - t.Fatalf("expected to get an account for a user %s", userId) - } - - if account != nil && account.Domain != domain { - t.Errorf("updating domain. expected %s got %s", domain, account.Domain) - } + accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account domain and category") + require.Equal(t, domain, accDomain, "expected account domain to match") } func TestAccountManager_GetAccountByUserID(t *testing.T) { @@ -908,12 +913,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { - account := newAccountWithId(context.Background(), accountID, userID, domain) - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { return nil, err } - return account, nil + return am.Store.GetAccount(context.Background(), accountID) } func TestAccountManager_GetAccount(t *testing.T) { @@ -1056,23 +1060,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") - if err != nil { - t.Fatal(err) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud") + require.NoError(t, err, "failed to get or create account by user") - serial := account.Network.CurrentSerial() // should be 0 + network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account network") - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } + serial := network.CurrentSerial() // should be 0 + require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0") key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return - } + require.NoError(t, err, "failed to generate private key") + expectedPeerKey := key.PublicKey().String() expectedUserID := userID @@ -1080,16 +1079,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) - if err != nil { - t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) - return - } + require.NoError(t, err, "failed to add peer") - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - return - } + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") if peer.Key != expectedPeerKey { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) @@ -1215,10 +1208,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + policyID := xid.New().String() policy := Policy{ - Enabled: true, + ID: policyID, + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { + ID: "rule", + PolicyID: policyID, Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -1252,19 +1250,25 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := group.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer3.ID}, + ID: "groupA", + AccountID: account.Id, + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, } if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } + policyID := xid.New().String() policy := Policy{ - Enabled: true, + ID: policyID, + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { + ID: "rule", + PolicyID: policyID, Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -1305,19 +1309,24 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - group := group.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + ID: "groupA", + AccountID: account.Id, + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, } + err := manager.SaveGroup(context.Background(), account.Id, userID, &group) + require.NoError(t, err, "failed to save group") + policyID := xid.New().String() policy := Policy{ - Enabled: true, + ID: policyID, + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { + ID: "rule", + PolicyID: policyID, Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -1327,6 +1336,9 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { }, } + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { t.Errorf("delete default rule: %v", err) return @@ -1355,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, userID, group.ID); err != nil { t.Errorf("delete group: %v", err) return } @@ -1754,12 +1766,6 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: true, - }) - require.NoError(t, err, "expecting to update account settings successfully but got error") - wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ @@ -1865,10 +1871,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: true, - }) + settings, err := manager.GetAccountSettings(context.Background(), accountID, userID) + require.NoError(t, err, "failed to get account settings") + + settings.PeerLoginExpirationEnabled = true + settings.PeerLoginExpiration = time.Hour + settings, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") failed := waitTimeout(wg, time.Second) @@ -1878,10 +1886,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: false, - }) + settings.PeerLoginExpirationEnabled = false + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") failed = waitTimeout(wg, time.Second) if failed { @@ -1896,30 +1902,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: false, - }) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") + + settings.PeerLoginExpirationEnabled = false + settings.PeerLoginExpiration = time.Hour + updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") - assert.False(t, settings.PeerLoginExpirationEnabled) assert.Equal(t, settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Second, - PeerLoginExpirationEnabled: false, - }) + settings.PeerLoginExpirationEnabled = false + settings.PeerLoginExpiration = time.Second + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour * 24 * 181, - PeerLoginExpirationEnabled: false, - }) + settings.PeerLoginExpirationEnabled = false + settings.PeerLoginExpiration = time.Hour * 24 * 181 + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c..b50a67594 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Fatal("failed to init testing account") } - dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) + dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID) if err != nil { t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{ DisabledManagementGroups: []string{group1ID}, - } + }) + require.NoError(t, err, "failed to update DNS settings") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save testing account with new DNS settings") - } - - dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) + dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID) if err != nil { t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) + _, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID) if err == nil { t.Errorf("An error should be returned when getting the DNS settings with a regular user") } @@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) + err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings) if err != nil { if testCase.shouldFail { return @@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error(err) } - updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) + updatedAccount, err := am.Store.GetAccount(context.Background(), accountID) if err != nil { t.Errorf("should be able to retrieve updated account, got err: %s", err) } @@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - peer1, err := account.FindPeerByPubKey(dnsPeer1Key) + peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key) if err != nil { t.Error("failed to init testing account") } - peer2, err := account.FindPeerByPubKey(dnsPeer2Key) + peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key) if err != nil { t.Error("failed to init testing account") } @@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") - dnsSettings := account.DNSSettings.Copy() + accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account DNS settings") + + dnsSettings := accountDNSSettings.Copy() dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) - account.DNSSettings = dnsSettings - err = am.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings) + require.NoError(t, err, "failed to update DNS settings") updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) @@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro domain := "example.com" - account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - - account.Users[dnsRegularUserID] = &User{ - Id: dnsRegularUserID, - Role: UserRoleUser, + err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain) + if err != nil { + return "", err } - err := am.Store.SaveAccount(context.Background(), account) + err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: dnsRegularUserID, + AccountID: dnsAccountID, + Role: UserRoleUser, + }) if err != nil { - return nil, err + return "", err } savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) if err != nil { - return nil, err + return "", err } - account, err = am.Store.GetAccount(context.Background(), account.Id) + peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key) if err != nil { - return nil, err + return "", err } - peer1, err = account.FindPeerByPubKey(peer1.Key) + _, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key) if err != nil { - return nil, err + return "", err } - _, err = account.FindPeerByPubKey(peer2.Key) + err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{ + { + ID: dnsGroup1ID, + AccountID: dnsAccountID, + Peers: []string{peer1.ID}, + Name: dnsGroup1ID, + }, + { + ID: dnsGroup2ID, + AccountID: dnsAccountID, + Name: dnsGroup2ID, + }, + }) if err != nil { - return nil, err + return "", err } - newGroup1 := &group.Group{ - ID: dnsGroup1ID, - Peers: []string{peer1.ID}, - Name: dnsGroup1ID, - } - - newGroup2 := &group.Group{ - ID: dnsGroup2ID, - Name: dnsGroup2ID, - } - - account.Groups[newGroup1.ID] = newGroup1 - account.Groups[newGroup2.ID] = newGroup2 - - allGroup, err := account.GetGroupAll() + allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All") if err != nil { - return nil, err + return "", err } - account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ - ID: dnsNSGroup1, - Name: "ns-group-1", + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{ + ID: dnsNSGroup1, + AccountID: dnsAccountID, + Name: "ns-group-1", NameServers: []dns.NameServer{{ IP: netip.MustParseAddr(savedPeer1.IP.String()), NSType: dns.UDPNameServerType, @@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro Primary: true, Enabled: true, Groups: []string{allGroup.ID}, - } - - err = am.Store.SaveAccount(context.Background(), account) + }) if err != nil { - return nil, err + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return dnsAccountID, nil } func generateTestData(size int) nbdns.Config { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 1390352a5..c1f79aad3 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,35 +7,35 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" + "github.com/stretchr/testify/require" ) type MockStore struct { Store - account *Account + accountID string } -func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { - return []*Account{s.account} -} +//func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { +// return []*Account{s.account} +//} -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { - _, ok := s.account.Peers[peerId] - if ok { - return s.account, nil - } - - return nil, status.NewPeerNotFoundError(peerId) -} +//func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { +// +// _, ok := s.account.Peers[peerId] +// if ok { +// return s.account, nil +// } +// +// return nil, status.NewPeerNotFoundError(peerId) +//} type MocAccountManager struct { AccountManager store *MockStore } -func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { - delete(a.store.account.Peers, peerID) - return nil //nolint:nil +func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error { + return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) } func TestNewManager(t *testing.T) { @@ -44,23 +44,26 @@ func TestNewManager(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) - if len(store.account.Peers) != numberOfPeers { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) - } + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers") } func TestNewManagerPeerConnected(t *testing.T) { @@ -76,19 +79,23 @@ func TestNewManagerPeerConnected(t *testing.T) { numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) - mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) + + peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0") + require.NoError(t, err, "failed to get peer") + + mgr.OnPeerConnected(context.Background(), peer) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) - expected := numberOfPeers + 1 - if len(store.account.Peers) != expected { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) - } + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers") } func TestNewManagerPeerDisconnected(t *testing.T) { @@ -104,43 +111,64 @@ func TestNewManagerPeerDisconnected(t *testing.T) { numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) - for _, v := range store.account.Peers { - mgr.OnPeerConnected(context.Background(), v) + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + for _, v := range peers { + mgr.OnPeerConnected(context.Background(), v) } - mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) + + peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0") + require.NoError(t, err, "failed to get peer") + mgr.OnPeerDisconnected(context.Background(), peer) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) + peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") expected := numberOfPeers + numberOfEphemeralPeers - 1 - if len(store.account.Peers) != expected { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) - } + require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers") } -func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { - store.account = newAccountWithId(context.Background(), "my account", "", "") +func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error { + accountID := "my account" + err := newAccountWithId(context.Background(), store, accountID, "", "") + if err != nil { + return err + } + store.accountID = accountID for i := 0; i < numberOfPeers; i++ { peerId := fmt.Sprintf("peer_%d", i) p := &nbpeer.Peer{ ID: peerId, + AccountID: accountID, Ephemeral: false, } - store.account.Peers[p.ID] = p + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, p) + if err != nil { + return err + } } for i := 0; i < numberOfEphemeralPeers; i++ { peerId := fmt.Sprintf("ephemeral_peer_%d", i) p := &nbpeer.Peer{ ID: peerId, + AccountID: accountID, Ephemeral: true, } - store.account.Peers[p.ID] = p + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, p) + if err != nil { + return err + } } + + return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e819..9012832e5 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A } routeResource := &route.Route{ - ID: "example route", - Groups: []string{groupForRoute.ID}, + ID: "example route", + AccountID: accountID, + Groups: []string{groupForRoute.ID}, } routePeerGroupResource := &route.Route{ ID: "example route with peer groups", + AccountID: accountID, PeerGroups: []string{groupForRoute2.ID}, } nameServerGroup := &nbdns.NameServerGroup{ - ID: "example name server group", - Groups: []string{groupForNameServerGroups.ID}, + ID: "example name server group", + AccountID: accountID, + Groups: []string{groupForNameServerGroups.ID}, } policy := &Policy{ - ID: "example policy", + ID: "example policy", + AccountID: accountID, Rules: []*PolicyRule{ { ID: "example policy rule", + PolicyID: "example policy", Destinations: []string{groupForPolicies.ID}, }, }, @@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A setupKey := &SetupKey{ Id: "example setup key", + AccountID: accountID, AutoGroups: []string{groupForSetupKeys.ID}, } user := &User{ Id: "example user", + AccountID: accountID, AutoGroups: []string{groupForUsers.ID}, } - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) - account.Routes[routeResource.ID] = routeResource - account.Routes[routePeerGroupResource.ID] = routePeerGroupResource - account.NameServerGroups[nameServerGroup.ID] = nameServerGroup - account.Policies = append(account.Policies, policy) - account.SetupKeys[setupKey.Id] = setupKey - account.Users[user.Id] = user - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain) if err != nil { return nil, nil, err } - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource) + if err != nil { + return nil, nil, err + } - acc, err := am.Store.GetAccount(context.Background(), account.Id) + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup) + if err != nil { + return nil, nil, err + } + + err = am.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user) + if err != nil { + return nil, nil, err + } + + err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{ + groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies, + groupForSetupKeys, groupForUsers, groupForUsers, groupForIntegration, + }) + if err != nil { + return nil, nil, err + } + + acc, err := am.Store.GetAccount(context.Background(), accountID) if err != nil { return nil, nil, err } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 7f3c81f18..8fd68c384 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -68,7 +68,7 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { return nil, fmt.Errorf("unknown group name") }, - GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + GetUserPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return maps.Values(TestPeers), nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index fdfb0ea24..0e0872d31 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -33,7 +33,8 @@ var testAccount = &server.Account{ Domain: domain, Users: map[string]*server.User{ userID: { - Id: userID, + Id: userID, + AccountID: accountID, PATs: map[string]*server.PersonalAccessToken{ tokenID: { ID: tokenID, @@ -49,11 +50,11 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) { if token == PAT { - return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } - return nil, nil, nil, fmt.Errorf("PAT invalid") + return nil, nil, "", "", fmt.Errorf("PAT invalid") } func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { @@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { ) authMiddleware := NewAuthMiddleware( - mockGetAccountFromPAT, + mockGetAccountInfoFromPAT, mockValidateAndParseToken, mockMarkPATUsed, mockCheckUserAccessByJWTGroups, diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dd49c03b8..3fa3ea2f9 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -39,6 +39,68 @@ const ( ) func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &server.Policy{ + ID: "policy", + AccountID: "test_id", + Name: "policy", + Enabled: true, + Rules: []*server.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, + }, + } + + srvUser := server.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &server.Account{ + Id: "test_id", + Domain: "hotmail.com", + Peers: peersMap, + Users: map[string]*server.User{ + adminUser: server.NewAdminUser(adminUser), + regularUser: server.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*nbgroup.Group{ + "group1": { + ID: "group1", + AccountID: "test_id", + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, + }, + Settings: &server.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + Policies: []*server.Policy{policy}, + Network: &server.Network{ + Identifier: "ciclqisab2ss43jdn8q0", + Net: net.IPNet{ + IP: net.ParseIP("100.67.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + Serial: 51, + }, + } + return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -64,77 +126,37 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { } return p, nil }, - GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + GetUserPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, + ListPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + return peers, nil + }, + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + peersID := make([]string, len(peers)) + for _, peer := range peers { + peersID = append(peersID, peer.ID) + } + return []*nbgroup.Group{ + { + ID: "group1", + AccountID: accountID, + Name: "group1", + Issued: "api", + Peers: peersID, + }, + }, nil + }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, + GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + return account, nil + }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - peersMap := make(map[string]*nbpeer.Peer) - for _, peer := range peers { - peersMap[peer.ID] = peer.Copy() - } - - policy := &server.Policy{ - ID: "policy", - AccountID: accountID, - Name: "policy", - Enabled: true, - Rules: []*server.PolicyRule{ - { - ID: "rule", - Name: "rule", - Enabled: true, - Action: "accept", - Destinations: []string{"group1"}, - Sources: []string{"group1"}, - Bidirectional: true, - Protocol: "all", - Ports: []string{"80"}, - }, - }, - } - - srvUser := server.NewRegularUser(serviceUser) - srvUser.IsServiceUser = true - - account := &server.Account{ - Id: accountID, - Domain: "hotmail.com", - Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), - serviceUser: srvUser, - }, - Groups: map[string]*nbgroup.Group{ - "group1": { - ID: "group1", - AccountID: accountID, - Name: "group1", - Issued: "api", - Peers: maps.Keys(peersMap), - }, - }, - Settings: &server.Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: time.Hour, - }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ - Identifier: "ciclqisab2ss43jdn8q0", - Net: net.IPNet{ - IP: net.ParseIP("100.67.0.0"), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - Serial: 51, - }, - } - return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf023..e03833535 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } outNSGroup, err := am.CreateNameServerGroup( context.Background(), - account.Id, + accountID, testCase.inputArgs.name, testCase.inputArgs.description, testCase.inputArgs.nameServers, @@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("account should be saved") - } + testCase.existingNSGroup.AccountID = accountID + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup) + require.NoError(t, err, "failed to save existing nameserver group") var nsGroupToSave *nbdns.NameServerGroup - if !testCase.skipCopying { nsGroupToSave = testCase.existingNSGroup.Copy() @@ -651,7 +648,7 @@ func TestSaveNameServerGroup(t *testing.T) { } } - err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) + err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave) testCase.errFunc(t, err) @@ -659,13 +656,8 @@ func TestSaveNameServerGroup(t *testing.T) { return } - account, err = am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - } - - savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID] - require.True(t, saved) + savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID) + require.NoError(t, err, "failed to get saved nameserver group") if !testCase.expectedNSGroup.IsEqual(savedNSGroup) { t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup) @@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.NameServerGroups[testingNSGroup.ID] = testingNSGroup + testingNSGroup.AccountID = accountID + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup) + require.NoError(t, err, "failed to save nameserver group") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save account") - } - - err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID) + err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID) if err != nil { t.Error("deleting nameserver group failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Error("failed to retrieve saved account with error: ", err) - } - - _, found := savedAccount.NameServerGroups[testingNSGroup.ID] - if found { - t.Error("nameserver group shouldn't be found after delete") - } + _, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID) + require.NotNil(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok, "error should be a status error") + assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete") } func TestGetNameServerGroup(t *testing.T) { @@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) + foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID) if err != nil { t.Error("getting existing nameserver group failed with error: ", err) } @@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("got a nil group while getting nameserver group with ID") } - _, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") + _, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing") if err == nil { t.Error("getting not existing nameserver group should return error, got nil") } @@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() + accountID := "testingAcc" + userID := testUserID + domain := "example.com" + peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, Name: "test-host1@netbird.io", @@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error } existingNSGroup := nbdns.NameServerGroup{ ID: existingNSGroupID, + AccountID: accountID, Name: existingNSGroupName, Description: "", NameServers: []nbdns.NameServer{ @@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error Enabled: true, } - accountID := "testingAcc" - userID := testUserID - domain := "example.com" - - account := newAccountWithId(context.Background(), accountID, userID, domain) - - account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - - newGroup1 := &nbgroup.Group{ - ID: group1ID, - Name: group1ID, - } - - newGroup2 := &nbgroup.Group{ - ID: group2ID, - Name: group2ID, - } - - account.Groups[newGroup1.ID] = newGroup1 - account.Groups[newGroup2.ID] = newGroup2 - - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { - return nil, err + return "", err + } + + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup) + if err != nil { + return "", err + } + + err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{ + { + ID: group1ID, + AccountID: accountID, + Name: group1ID, + }, + { + ID: group2ID, + AccountID: accountID, + Name: group2ID, + }, + }) + if err != nil { + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) if err != nil { - return nil, err + return "", err } - return account, nil + return accountID, nil } func TestValidateDomain(t *testing.T) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b48e94273..573e57a21 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -467,21 +467,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ - Id: someUser, - Role: UserRoleUser, - } - account.Settings.RegularUsersViewBlocked = false + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + require.NoError(t, err, "failed to create account") - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - return - } + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: someUser, + AccountID: accountID, + Role: UserRoleUser, + }) + require.NoError(t, err, "failed to create user") + + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = false + err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err, "failed to save account settings") // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -535,7 +539,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { assert.NotNil(t, peer) // delete the all-to-all policy so that user's peer1 has no access to peer2 - for _, policy := range account.Policies { + accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account policies") + + for _, policy := range accountPolicies { err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) if err != nil { t.Fatal(err) @@ -654,21 +661,33 @@ func TestDefaultAccountManager_GetUserPeers(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + require.NoError(t, err, "failed to create account") + + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: someUser, + AccountID: accountID, Role: testCase.role, IsServiceUser: testCase.isServiceUser, - } - account.Policies = []*Policy{} - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + }) + require.NoError(t, err, "failed to create user") - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - return + accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account policies") + + for _, policy := range accountPolicies { + err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) + require.NoError(t, err, "failed to delete policy") } + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = testCase.limitedViewSettings + err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err, "failed to save account settings") + peerKey1, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) @@ -724,10 +743,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou adminUser := "account_creator" regularUser := "regular_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ - Id: regularUser, - Role: UserRoleUser, + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + if err != nil { + return nil, "", "", err + } + + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: regularUser, + AccountID: accountID, + Role: UserRoleUser, + }) + if err != nil { + return nil, "", "", err } // Create peers @@ -741,31 +768,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Status: &nbpeer.PeerStatus{}, UserID: regularUser, } - account.Peers[peer.ID] = peer + err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer) + if err != nil { + return nil, "", "", err + } } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &nbgroup.Group{ - ID: groupID, - Name: fmt.Sprintf("Group %d", i), + ID: groupID, + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), } for j := 0; j < peers/groups; j++ { peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - account.Groups[groupID] = group + + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + return nil, "", "", err + } // Create a policy for this group policy := &Policy{ - ID: fmt.Sprintf("policy-%d", i), - Name: fmt.Sprintf("Policy for Group %d", i), - Enabled: true, + ID: fmt.Sprintf("policy-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, Rules: []*PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), + PolicyID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), Enabled: true, Sources: []string{groupID}, @@ -776,22 +812,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou }, }, } - account.Policies = append(account.Policies, policy) + + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + if err != nil { + return nil, "", "", err + } } - account.PostureChecks = []*posture.Checks{ - { - ID: "PostureChecksAll", - Name: "All", - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.0.1", - }, + err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{ + ID: "PostureChecksAll", + AccountID: accountID, + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", }, }, - } - - err = manager.Store.SaveAccount(context.Background(), account) + }) if err != nil { return nil, "", "", err } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 48065d5ad..2c1378944 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -45,7 +45,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) + return status.Errorf(status.InvalidArgument, err.Error()) //nolint } updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 6775583fd..c4a6525c1 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -26,22 +27,22 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestPostureChecksAccount(am) + accountID, err := initTestPostureChecksAccount(am) if err != nil { t.Error("failed to init testing account") } t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}, false) + err := am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}, false) assert.Error(t, err) // regular users cannot list check - _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) + _, err = am.ListPostureChecks(context.Background(), accountID, regularUserID) assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ ID: postureCheckID, Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -53,12 +54,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // admin users can list check - checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) + checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ ID: "new-id", Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -74,7 +75,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ ID: postureCheckID, Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -86,41 +87,44 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, adminUserID) assert.NoError(t, err) - checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) + checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 0) }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ - Id: adminUserID, - Role: UserRoleAdmin, - } - user := &User{ - Id: regularUserID, - Role: UserRoleUser, - } - - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) - account.Users[admin.Id] = admin - account.Users[user.Id] = user - - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain) if err != nil { - return nil, err + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: adminUserID, + AccountID: accountID, + Role: UserRoleAdmin, + }, + { + Id: regularUserID, + AccountID: accountID, + Role: UserRoleUser, + }, + }) + if err != nil { + return "", err + } + + return accountID, nil } func TestPostureCheckAccountPeersUpdate(t *testing.T) { @@ -169,7 +173,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false) assert.NoError(t, err) select { @@ -192,7 +196,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true) assert.NoError(t, err) select { @@ -255,7 +259,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true) assert.NoError(t, err) select { @@ -303,7 +307,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -337,7 +341,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true) assert.NoError(t, err) select { @@ -384,7 +388,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true) assert.NoError(t, err) select { @@ -429,7 +433,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true) assert.NoError(t, err) select { @@ -441,79 +445,113 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, - }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, - { - ID: "checkB", - }, - }, + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + accountID, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := &group.Group{ + ID: "groupA", + AccountID: accountID, + Peers: []string{"peer1"}, } + groupB := &group.Group{ + ID: "groupA", + AccountID: accountID, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + policy := &Policy{ + ID: "policyA", + AccountID: accountID, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{"checkA"}, + } + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err, "failed to save policy") + + postureCheckA := &posture.Checks{ + ID: "checkA", + AccountID: accountID, + } + err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + ID: "checkB", + AccountID: accountID, + } + err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkB", true) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "unknown", false) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err, "failed to update policy") + + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err, "failed to update policy") + + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err, "failed to update policy") + + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c..870ed5997 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,12 +5,15 @@ import ( "fmt" "net" "net/netip" + "strings" "testing" "time" + "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" @@ -427,21 +430,22 @@ func TestCreateRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Errorf("failed to init testing account: %s", err) } if testCase.createInitRoute { - groupAll, errInit := account.GetGroupAll() + groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) + + _, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -917,14 +921,15 @@ func TestSaveRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } if testCase.createInitRoute { - account.Routes["initRoute"] = &route.Route{ + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, &route.Route{ ID: "initRoute", + AccountID: accountID, Network: existingNetwork, NetID: existingRouteID, NetworkType: route.IPv4Network, @@ -934,18 +939,14 @@ func TestSaveRoute(t *testing.T) { Metric: 9999, Enabled: true, Groups: []string{routeGroup1}, - } + }) + require.NoError(t, err, "failed to save init route") } - account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("account should be saved") - } + err = am.SaveRoute(context.Background(), accountID, userID, testCase.existingRoute) + require.NoError(t, err, "failed to save existing route") var routeToSave *route.Route - if !testCase.skipCopying { routeToSave = testCase.existingRoute.Copy() if testCase.newPeer != nil { @@ -977,21 +978,15 @@ func TestSaveRoute(t *testing.T) { } } - err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) - + err = am.SaveRoute(context.Background(), accountID, userID, routeToSave) testCase.errFunc(t, err) if !testCase.shouldCreate { return } - account, err = am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - } - - savedRoute, saved := account.Routes[testCase.expectedRoute.ID] - require.True(t, saved) + savedRoute, err := am.Store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(testCase.expectedRoute.ID)) + require.NoError(t, err, "failed to retrieve saved route") if !testCase.expectedRoute.IsEqual(savedRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) @@ -1019,32 +1014,26 @@ func TestDeleteRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.Routes[testingRoute.ID] = testingRoute + testingRoute.AccountID = accountID + err = am.SaveRoute(context.Background(), accountID, userID, testingRoute) + require.NoError(t, err, "failed to save testing route") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save account") - } - - err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, testingRoute.ID, userID) if err != nil { t.Error("deleting route failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Error("failed to retrieve saved account with error: ", err) - } + _, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID) + require.NotNil(t, err) - _, found := savedAccount.Routes[testingRoute.ID] - if found { - t.Error("route shouldn't be found after delete") - } + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.NotFound, sErr.Type()) } func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { @@ -1066,16 +1055,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } + accountID, err := initTestRouteAccount(t, am) + require.NoError(t, err, "failed to init testing account") newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1091,7 +1078,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { @@ -1103,21 +1090,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { } } - err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) + err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID) require.NoError(t, err) peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") - err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) + err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID) require.NoError(t, err) peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") - err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID) require.NoError(t, err) peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1128,7 +1115,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") - err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1158,16 +1145,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } + accountID, err := initTestRouteAccount(t, am) + require.NoError(t, err, "failed to init testing account") newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1181,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { expectedRoute := enabledRoute.Copy() expectedRoute.Peer = peer1Key - err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) + err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute) require.NoError(t, err) peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1193,7 +1178,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") - err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) + err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID) require.NoError(t, err) peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) @@ -1206,10 +1191,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) + err = am.SaveGroup(context.Background(), accountID, userID, newGroup) require.NoError(t, err) - rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") + rules, err := am.ListPolicies(context.Background(), accountID, "testingUser") require.NoError(t, err) defaultRule := rules[0] @@ -1219,10 +1204,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) + err = am.SavePolicy(context.Background(), accountID, userID, newPolicy, false) require.NoError(t, err) - err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) + err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID) require.NoError(t, err) peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1233,7 +1218,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") - err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1267,179 +1252,103 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() accountID := "testingAcc" domain := "example.com" - account := newAccountWithId(context.Background(), accountID, userID, domain) - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { - return nil, err + return "", err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) + createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) { + ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerIP, err := AllocatePeerIP(network.Net, ips) + if err != nil { + return nil, err + } + + peer := &nbpeer.Peer{ + IP: peerIP, + ID: peerID, + Key: peerKey, + Name: peerName, + DNSLabel: dnsLabel, + UserID: userID, + Meta: nbpeer.PeerSystemMeta{ + Hostname: peerName, + GoOS: strings.ToLower(os), + Kernel: kernel, + Core: core, + Platform: platform, + OS: os, + WtVersion: "development", + UIVersion: "development", + }, + Status: &nbpeer.PeerStatus{}, + } + if err := am.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer); err != nil { + return nil, err + } + return peer, nil + } + + // Create peers + peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer1 := &nbpeer.Peer{ - IP: peer1IP, - ID: peer1ID, - Key: peer1Key, - Name: "test-host1@netbird.io", - DNSLabel: "test-host1", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host1@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer1.ID] = peer1 - - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) + peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer2 := &nbpeer.Peer{ - IP: peer2IP, - ID: peer2ID, - Key: peer2Key, - Name: "test-host2@netbird.io", - DNSLabel: "test-host2", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host2@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer2.ID] = peer2 - - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin") if err != nil { - return nil, err + return "", err } - - peer3 := &nbpeer.Peer{ - IP: peer3IP, - ID: peer3ID, - Key: peer3Key, - Name: "test-host3@netbird.io", - DNSLabel: "test-host3", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host3@netbird.io", - GoOS: "darwin", - Kernel: "Darwin", - Core: "13.4.1", - Platform: "arm64", - OS: "darwin", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer3.ID] = peer3 - - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer4 := &nbpeer.Peer{ - IP: peer4IP, - ID: peer4ID, - Key: peer4Key, - Name: "test-host4@netbird.io", - DNSLabel: "test-host4", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host4@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer4.ID] = peer4 - - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - peer5 := &nbpeer.Peer{ - IP: peer5IP, - ID: peer5ID, - Key: peer5Key, - Name: "test-host5@netbird.io", - DNSLabel: "test-host5", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host5@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, + groupAll, err := am.GetGroupByName(context.Background(), "All", accountID) + if err != nil { + return "", err } - account.Peers[peer5.ID] = peer5 - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - return nil, err - } - groupAll, err := account.GetGroupAll() - if err != nil { - return nil, err - } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) if err != nil { - return nil, err + return "", err } - newGroup := []*nbgroup.Group{ + newGroups := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1471,15 +1380,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er Peers: []string{peer1.ID, peer4.ID}, }, } - - for _, group := range newGroup { - err = am.SaveGroup(context.Background(), accountID, userID, group) - if err != nil { - return nil, err - } + err = am.SaveGroups(context.Background(), accountID, userID, newGroups) + if err != nil { + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return accountID, nil } func TestAccount_getPeersRoutesFirewall(t *testing.T) { @@ -1783,10 +1689,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") - account, err := initTestRouteAccount(t, manager) + accountID, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{ { ID: "groupA", Name: "GroupA", @@ -1832,7 +1738,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() _, err := manager.CreateRoute( - context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.Groups, []string{}, true, userID, route.KeepRoute, ) @@ -1868,7 +1774,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() _, err := manager.CreateRoute( - context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.Groups, []string{}, true, userID, route.KeepRoute, ) @@ -1904,7 +1810,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() newRoute, err := manager.CreateRoute( - context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, + context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, ) @@ -1928,7 +1834,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute) require.NoError(t, err) select { @@ -1946,7 +1852,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) + err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID) require.NoError(t, err) select { @@ -1970,7 +1876,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Groups: []string{routeGroup1}, } _, err := manager.CreateRoute( - context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) @@ -1982,7 +1888,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2010,7 +1916,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Groups: []string{"groupC"}, } _, err := manager.CreateRoute( - context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) @@ -2022,7 +1928,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 2ed8aef95..ebbb5980c 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -25,12 +25,10 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") - if err != nil { - t.Fatal(err) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") + require.NoError(t, err, "failed to get or create account ID") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,7 +47,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, + key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{}, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) @@ -58,7 +56,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups := []string{"group_1", "group_2"} newKeyName := "my-new-test-key" revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -72,22 +70,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated - ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) + ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked) assert.NotNil(t, ev) - assert.Equal(t, account.Id, ev.AccountID) + assert.Equal(t, accountID, ev.AccountID) assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, key.Id, ev.TargetID) - groupAll, err := account.GetGroupAll() - assert.NoError(t, err) + groupAll, err := manager.GetGroupByName(context.Background(), accountID, "All") + require.NoError(t, err) // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -103,12 +101,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") - if err != nil { - t.Fatal(err) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") + require.NoError(t, err, "failed to get or create account ID") - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -117,7 +113,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -126,8 +122,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - groupAll, err := account.GetGroupAll() - assert.NoError(t, err) + groupAll, err := manager.GetGroupByName(context.Background(), accountID, "All") + require.NoError(t, err) type testCase struct { name string @@ -170,7 +166,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, + key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn, tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { @@ -189,10 +185,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated - ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) + ev := getEvent(t, accountID, manager, activity.SetupKeyCreated) assert.NotNil(t, ev) - assert.Equal(t, account.Id, ev.AccountID) + assert.Equal(t, accountID, ev.AccountID) assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"]) assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) @@ -208,12 +204,10 @@ func TestGetSetupKeys(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") - if err != nil { - t.Fatal(err) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") + require.NoError(t, err, "failed to get or create account ID") - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -222,7 +216,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index b04060583..bb2a8f15d 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -68,17 +68,23 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { func runLargeTest(t *testing.T, store Store) { t.Helper() - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - groupALL, err := account.GetGroupAll() - if err != nil { - t.Fatal(err) - } + accountID := "account_id" + + err := newAccountWithId(context.Background(), store, accountID, "testuser", "") + assert.NoError(t, err, "failed to create account") + + groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") + assert.NoError(t, err, "failed to get group All") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + assert.NoError(t, err, "failed to save setup key") + const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { netIP := randomIPv4() - peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + peerID := fmt.Sprintf("%s-peer-%d", accountID, n) peer := &nbpeer.Peer{ ID: peerID, @@ -90,16 +96,21 @@ func runLargeTest(t *testing.T, store Store) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } - account.Peers[peerID] = peer - group, _ := account.GetGroupAll() - group.Peers = append(group.Peers, peerID) - user := &User{ - Id: fmt.Sprintf("%s-user-%d", account.Id, n), - AccountID: account.Id, - } - account.Users[user.Id] = user + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer) + assert.NoError(t, err, "failed to save peer") + + err = store.AddPeerToAllGroup(context.Background(), accountID, peerID) + assert.NoError(t, err, "failed to add peer to all group") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: fmt.Sprintf("%s-user-%d", accountID, n), + AccountID: accountID, + }) + assert.NoError(t, err, "failed to save user") + route := &route2.Route{ ID: route2.ID(fmt.Sprintf("network-id-%d", n)), + AccountID: accountID, Description: "base route", NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), Network: netip.MustParsePrefix(netIP.String() + "/24"), @@ -107,22 +118,24 @@ func runLargeTest(t *testing.T, store Store) { Metric: 9999, Masquerade: false, Enabled: true, - Groups: []string{groupALL.ID}, + Groups: []string{groupAll.ID}, } - account.Routes[route.ID] = route + err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + assert.NoError(t, err, "failed to save route") - group = &nbgroup.Group{ + group := &nbgroup.Group{ ID: fmt.Sprintf("group-id-%d", n), - AccountID: account.Id, + AccountID: accountID, Name: fmt.Sprintf("group-id-%d", n), Issued: "api", Peers: nil, } - account.Groups[group.ID] = group + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + assert.NoError(t, err, "failed to save group") nameserver := &nbdns.NameServerGroup{ ID: fmt.Sprintf("nameserver-id-%d", n), - AccountID: account.Id, + AccountID: accountID, Name: fmt.Sprintf("nameserver-id-%d", n), Description: "", NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, @@ -132,20 +145,20 @@ func runLargeTest(t *testing.T, store Store) { Enabled: false, SearchDomainsEnabled: false, } - account.NameServerGroups[nameserver.ID] = nameserver + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver) + assert.NoError(t, err, "failed to save nameserver group") - setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey + setupKey, _ = GenerateDefaultSetupKey() + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + assert.NoError(t, err, "failed to save setup key") } - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -213,41 +226,49 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) - assert.NoError(t, err) + require.NoError(t, err) + + accountID := "account_id" + err = newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") - account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ Key: "peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + }) + require.NoError(t, err, "failed to save peer") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + accountID2 := "account_id2" + err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "") + require.NoError(t, err, "failed to create account") - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey, _ = GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ + setupKey.AccountID = accountID2 + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID2, &nbpeer.Peer{ Key: "peerkey2", IP: net.IP{127, 0, 0, 2}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name 2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + }) + require.NoError(t, err, "failed to save peer") if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -288,36 +309,56 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) + accountID := "account_id" testUserID := "testuser" + user := NewAdminUser(testUserID) user.PATs = map[string]*PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} - account := newAccountWithId(context.Background(), "account_id", testUserID, "") + err = newAccountWithId(context.Background(), store, accountID, testUserID, "") + require.NoError(t, err) + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ Key: "peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Users[testUserID] = user + }) + require.NoError(t, err, "failed to save peer") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: "testtoken", + UserID: testUserID, + Name: "test token", + }) + require.NoError(t, err, "failed to save personal access token") - if len(store.GetAllAccounts(context.Background())) != 1 { + accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthShare) + require.NoError(t, err, "failed to get all account ids") + + if len(accountIDs) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts(context.Background())) != 0 { + accountIDs, err = store.GetAllAccountIDs(context.Background(), LockingStrengthShare) + require.NoError(t, err, "failed to get all account ids after DeleteAccount()") + + if len(accountIDs) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } @@ -714,19 +755,28 @@ func newSqliteStore(t *testing.T) *SqlStore { } func newAccount(store Store, id int) error { - str := fmt.Sprintf("%s-%d", uuid.New().String(), id) - account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") + accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id) + userID := accountID + "-testuser" + + err := newAccountWithId(context.Background(), store, accountID, userID, "example.com") + if err != nil { + return err + } + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["p"+str] = &nbpeer.Peer{ - Key: "peerkey" + str, + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + if err != nil { + return err + } + + return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ + Key: accountID + "-peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - return store.SaveAccount(context.Background(), account) + }) } func TestPostgresql_NewStore(t *testing.T) { @@ -754,39 +804,52 @@ func TestPostgresql_SaveAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + accountID := "account_id" + + err = newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ Key: "peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + }) + require.NoError(t, err, "failed to save peer") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + accountID2 := "account_id2" + + err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "") + require.NoError(t, err, "failed to create account") - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey, _ = GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ + setupKey.AccountID = accountID2 + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID2, &nbpeer.Peer{ Key: "peerkey2", IP: net.IP{127, 0, 0, 2}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name 2", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + }) + require.NoError(t, err, "failed to save peer") - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate) + require.NoError(t, err, "failed to get all account ids") - if len(store.GetAllAccounts(context.Background())) != 2 { + if len(accountIDs) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -827,32 +890,49 @@ func TestPostgresql_DeleteAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) + accountID := "account_id" testUserID := "testuser" + user := NewAdminUser(testUserID) user.PATs = map[string]*PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} - account := newAccountWithId(context.Background(), "account_id", testUserID, "") + err = newAccountWithId(context.Background(), store, accountID, testUserID, "") + require.NoError(t, err, "failed to create account") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ Key: "peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Users[testUserID] = user + }) + require.NoError(t, err, "failed to save peer") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: "testtoken", + UserID: testUserID, + Name: "test token", + }) + require.NoError(t, err, "failed to save personal access token") - if len(store.GetAllAccounts(context.Background())) != 1 { + accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate) + require.NoError(t, err, "failed to get all account ids") + + if len(accountIDs) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) @@ -1218,7 +1298,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "example.com" category := "public" IsDomainPrimaryAccount := false - err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount) require.NoError(t, err) account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) @@ -1232,7 +1312,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "test.com" category := "private" IsDomainPrimaryAccount := true - err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount) require.NoError(t, err) account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) @@ -1246,7 +1326,9 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "test.com" category := "private" IsDomainPrimaryAccount := true - err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id", + domain, category, &IsDomainPrimaryAccount, + ) require.Error(t, err) }) @@ -1274,7 +1356,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID) require.NoError(t, err) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) @@ -1290,6 +1372,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" nonExistingKeyID := "non-existing-key-id" - err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } diff --git a/management/server/user.go b/management/server/user.go index fff7e8aec..e7a3708e4 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -962,7 +962,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountIDByUser(ctx context.Context, } } - return "", nil + return accountID, nil } // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return diff --git a/management/server/user_test.go b/management/server/user_test.go index 91a9d6245..709c3c5a5 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -43,12 +43,9 @@ const ( func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockTargetUserId, + AccountID: mockAccountID, IsServiceUser: false, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockTargetUserId, + AccountID: mockAccountID, IsServiceUser: true, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithEmptyName(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) - assert.Errorf(t, err, "Wrong expiration should thorw error") + assert.Errorf(t, err, "Wrong expiration should throw error") } func TestUser_DeletePAT(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) { t.Fatalf("Error when adding PAT to user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when getting account: %s", err) } @@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - mockTokenID2: { - ID: mockTokenID2, - HashedToken: mockToken2, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID2, + UserID: mockUserID, + HashedToken: mockToken2, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) { func TestUser_CreateServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) { t.Fatalf("Error when creating service user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) assert.Equal(t, 2, len(account.Users)) @@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { t.Fatalf("Error when creating user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) assert.True(t, user.IsServiceUser) @@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -549,13 +519,12 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser) + assert.NoError(t, err, "failed to create service user") am := DefaultAccountManager{ Store: store, @@ -582,12 +551,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -603,39 +569,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - targetId := "user2" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: true, - ServiceUserName: "user2username", - } - targetId = "user3" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - } - targetId = "user4" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedIntegration, - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") - targetId = "user5" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: "user2", + AccountID: mockAccountID, + IsServiceUser: true, + ServiceUserName: "user2username", + }, + { + Id: "user3", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user4", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedIntegration, + }, + { + Id: "user5", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + }, + }) + assert.NoError(t, err, "failed to save users") am := DefaultAccountManager{ Store: store, @@ -685,61 +650,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_RegularUsers(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - targetId := "user2" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: true, - ServiceUserName: "user2username", - } - targetId = "user3" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - } - targetId = "user4" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedIntegration, - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") - targetId = "user5" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, - } - account.Users["user6"] = &User{ - Id: "user6", - IsServiceUser: false, - Issued: UserIssuedAPI, - } - account.Users["user7"] = &User{ - Id: "user7", - IsServiceUser: false, - Issued: UserIssuedAPI, - } - account.Users["user8"] = &User{ - Id: "user8", - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, - } - account.Users["user9"] = &User{ - Id: "user9", - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: "user2", + AccountID: mockAccountID, + IsServiceUser: true, + ServiceUserName: "user2username", + }, + { + Id: "user3", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user4", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedIntegration, + }, + { + Id: "user5", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + }, + { + Id: "user6", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user7", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user8", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + }, + { + Id: "user9", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + }, + }) + assert.NoError(t, err) am := DefaultAccountManager{ Store: store, @@ -816,7 +784,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.Store.GetAccount(context.Background(), account.Id) + acc, err := am.Store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) for _, id := range tc.expectedDeleted { @@ -836,12 +804,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -865,14 +830,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + newUser := NewRegularUser("normal_user1") + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") + + newUser = NewRegularUser("normal_user2") + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -946,15 +916,24 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") + + settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID) + assert.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = testCase.limitedViewSettings + + err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings) + assert.NoError(t, err, "failed to save account settings") + + err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID) + assert.NoError(t, err, "failed to delete user") am := DefaultAccountManager{ Store: store, @@ -968,7 +947,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { assert.Equal(t, 1, len(users)) - userInfo, _ := users[0].ToUserInfo(nil, account.Settings) + userInfo, _ := users[0].ToUserInfo(nil, settings) assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) }) } @@ -978,22 +957,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ - Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: "externalUser", + AccountID: mockAccountID, + Role: UserRoleUser, + Issued: UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", }, - } - account.Users[externalUser.Id] = externalUser - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1013,6 +991,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { assert.NoError(t, err) cacheManager := am.GetExternalCacheManager() + + externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser") + assert.NoError(t, err, "failed to get user") + cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) assert.NoError(t, err) @@ -1042,17 +1024,17 @@ func TestUser_IsAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockServiceUserID, + AccountID: mockAccountID, Role: "user", IsServiceUser: true, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1071,17 +1053,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockServiceUserID, + AccountID: mockAccountID, Role: "user", IsServiceUser: true, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1240,21 +1221,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io") if err != nil { t.Fatal(err) } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) + regularUser := NewRegularUser(regularUserID) + regularUser.AccountID = accountID + + adminUser := NewAdminUser(adminUserID) + adminUser.AccountID = accountID + + serviceUser := &User{ + Id: serviceUserID, + AccountID: accountID, + IsServiceUser: true, + Role: UserRoleAdmin, + ServiceUserName: "service", } - updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) + err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser}) + assert.NoError(t, err, "failed to save users") + + updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update) if tc.expectedErr { require.Errorf(t, err, "expecting SaveUser to throw an error") } else {