diff --git a/management/server/account_test.go b/management/server/account_test.go index 1e1a4c357..4bf738523 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -401,7 +401,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } for _, testCase := range tt { - account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") + store := newStore(t) + + err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io") + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), "account-1") + require.NoError(t, err, "failed to get account") + account.UpdateSettings(&testCase.accountSettings) account.Network = network account.Peers = testCase.peers @@ -419,6 +426,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) + + store.Close(context.Background()) } } @@ -426,27 +435,35 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" accountID := "account_id" - account := newAccountWithId(context.Background(), accountID, userId, domain) + + store := newStore(t) + defer store.Close(context.Background()) + + err := newAccountWithId(context.Background(), store, accountID, userId, domain) + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } -func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { +func TestAccountManager_GetOrCreateAccountIDByUser(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userID) return } - account, err = manager.Store.GetAccountByUser(context.Background(), userID) + account, err := manager.Store.GetAccountByUser(context.Background(), userID) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) return @@ -669,15 +686,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") - // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserID where the id is getting generated - // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") @@ -693,44 +707,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account groups") - require.Len(t, account.Groups, 1, "only ALL group should exists") + require.Len(t, accountGroups, 1, "only ALL group should exists") }) t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true - err := manager.Store.SaveAccount(context.Background(), initAccount) - require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings) + require.NoError(t, err, "failed to update account settings") + + totalAccounts, err := manager.Store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(1), totalAccounts, "only one account should exist") accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account groups") - require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") + require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT") }) t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" - err := manager.Store.SaveAccount(context.Background(), initAccount) - require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings) + require.NoError(t, err, "failed to update account settings") + + totalAccounts, err := manager.Store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(1), totalAccounts, "only one account should exist") accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to check account existence") + require.True(t, exists, "account should exist") - require.Len(t, account.Groups, 3, "groups should be added to the account") + accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId) + require.NoError(t, err, "failed to get account groups") + require.Len(t, accountGroups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} - for _, g := range account.Groups { + for _, g := range accountGroups { groupsByNames[g.Name] = g } @@ -746,27 +769,23 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { }) } -func TestAccountManager_GetAccountFromPAT(t *testing.T) { +func TestAccountManager_GetAccountInfoFromPAT(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + err := newAccountWithId(context.Background(), store, "account_id", "testuser", "") + require.NoError(t, err, "failed to create account") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ - Id: "someUser", - PATs: map[string]*PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - UserID: "someUser", - HashedToken: encodedHashedToken, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) + + userPAT := &PersonalAccessToken{ + ID: "tokenId", + UserID: "testuser", + HashedToken: encodedHashedToken, + CreatedAt: time.Now().UTC(), } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT) + require.NoError(t, err, "failed to save PAT") am := DefaultAccountManager{ Store: store, @@ -778,31 +797,27 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { } assert.Equal(t, "account_id", user.AccountID) - assert.Equal(t, "someUser", user.Id) - assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) + assert.Equal(t, "testuser", user.Id) + assert.Equal(t, userPAT, pat) } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + err := newAccountWithId(context.Background(), store, "account_id", "testuser", "") + require.NoError(t, err, "failed to create account") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ - Id: "someUser", - PATs: map[string]*PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - HashedToken: encodedHashedToken, - LastUsed: time.Time{}, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) + + userPAT := &PersonalAccessToken{ + ID: "tokenId", + UserID: "someUser", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT) + require.NoError(t, err, "failed to save PAT") am := DefaultAccountManager{ Store: store, @@ -813,11 +828,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { t.Fatalf("Error when marking PAT used: %s", err) } - account, err = am.Store.GetAccount(context.Background(), "account_id") - if err != nil { - t.Fatalf("Error when getting account: %s", err) - } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) + userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID) + require.NoError(t, err, "failed to get PAT") + + assert.True(t, !userPAT.LastUsed.IsZero()) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -828,15 +842,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.Store.GetAccountByUser(context.Background(), userId) + account, err := manager.Store.GetAccountByUser(context.Background(), userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -855,32 +869,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { userId := "test_user" domain := "hotmail.com" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) - if err != nil { - t.Fatal(err) - } - if account == nil { - t.Fatalf("expected to create an account for a user %s", userId) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain) + require.NoError(t, err, "failed to get or create account by user") + require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId) - if account != nil && account.Domain != domain { - t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) - } + accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account domain and category") + require.Equal(t, domain, accDomain, "expected account domain to match") domain = "gmail.com" - account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) - if err != nil { - t.Fatalf("got the following error while retrieving existing acc: %v", err) - } + accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain) + require.NoError(t, err, "failed to get or create account by user") - if account == nil { - t.Fatalf("expected to get an account for a user %s", userId) - } - - if account != nil && account.Domain != domain { - t.Errorf("updating domain. expected %s got %s", domain, account.Domain) - } + accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account domain and category") + require.Equal(t, domain, accDomain, "expected account domain to match") } func TestAccountManager_GetAccountByUserID(t *testing.T) { @@ -912,12 +916,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { - account := newAccountWithId(context.Background(), accountID, userID, domain) - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { return nil, err } - return account, nil + return am.Store.GetAccount(context.Background(), accountID) } func TestAccountManager_GetAccount(t *testing.T) { @@ -1164,23 +1167,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") - if err != nil { - t.Fatal(err) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud") + require.NoError(t, err, "failed to get or create account by user") - serial := account.Network.CurrentSerial() // should be 0 + network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account network") - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } + serial := network.CurrentSerial() // should be 0 + require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0") key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return - } + require.NoError(t, err, "failed to generate private key") + expectedPeerKey := key.PublicKey().String() expectedUserID := userID @@ -1188,16 +1186,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) - if err != nil { - t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) - return - } + require.NoError(t, err, "failed to add peer") - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - return - } + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") if peer.Key != expectedPeerKey { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c..b50a67594 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Fatal("failed to init testing account") } - dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) + dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID) if err != nil { t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{ DisabledManagementGroups: []string{group1ID}, - } + }) + require.NoError(t, err, "failed to update DNS settings") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save testing account with new DNS settings") - } - - dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) + dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID) if err != nil { t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) + _, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID) if err == nil { t.Errorf("An error should be returned when getting the DNS settings with a regular user") } @@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) + err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings) if err != nil { if testCase.shouldFail { return @@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error(err) } - updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) + updatedAccount, err := am.Store.GetAccount(context.Background(), accountID) if err != nil { t.Errorf("should be able to retrieve updated account, got err: %s", err) } @@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - peer1, err := account.FindPeerByPubKey(dnsPeer1Key) + peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key) if err != nil { t.Error("failed to init testing account") } - peer2, err := account.FindPeerByPubKey(dnsPeer2Key) + peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key) if err != nil { t.Error("failed to init testing account") } @@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") - dnsSettings := account.DNSSettings.Copy() + accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account DNS settings") + + dnsSettings := accountDNSSettings.Copy() dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) - account.DNSSettings = dnsSettings - err = am.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings) + require.NoError(t, err, "failed to update DNS settings") updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) @@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro domain := "example.com" - account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - - account.Users[dnsRegularUserID] = &User{ - Id: dnsRegularUserID, - Role: UserRoleUser, + err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain) + if err != nil { + return "", err } - err := am.Store.SaveAccount(context.Background(), account) + err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: dnsRegularUserID, + AccountID: dnsAccountID, + Role: UserRoleUser, + }) if err != nil { - return nil, err + return "", err } savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) if err != nil { - return nil, err + return "", err } - account, err = am.Store.GetAccount(context.Background(), account.Id) + peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key) if err != nil { - return nil, err + return "", err } - peer1, err = account.FindPeerByPubKey(peer1.Key) + _, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key) if err != nil { - return nil, err + return "", err } - _, err = account.FindPeerByPubKey(peer2.Key) + err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{ + { + ID: dnsGroup1ID, + AccountID: dnsAccountID, + Peers: []string{peer1.ID}, + Name: dnsGroup1ID, + }, + { + ID: dnsGroup2ID, + AccountID: dnsAccountID, + Name: dnsGroup2ID, + }, + }) if err != nil { - return nil, err + return "", err } - newGroup1 := &group.Group{ - ID: dnsGroup1ID, - Peers: []string{peer1.ID}, - Name: dnsGroup1ID, - } - - newGroup2 := &group.Group{ - ID: dnsGroup2ID, - Name: dnsGroup2ID, - } - - account.Groups[newGroup1.ID] = newGroup1 - account.Groups[newGroup2.ID] = newGroup2 - - allGroup, err := account.GetGroupAll() + allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All") if err != nil { - return nil, err + return "", err } - account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ - ID: dnsNSGroup1, - Name: "ns-group-1", + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{ + ID: dnsNSGroup1, + AccountID: dnsAccountID, + Name: "ns-group-1", NameServers: []dns.NameServer{{ IP: netip.MustParseAddr(savedPeer1.IP.String()), NSType: dns.UDPNameServerType, @@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro Primary: true, Enabled: true, Groups: []string{allGroup.ID}, - } - - err = am.Store.SaveAccount(context.Background(), account) + }) if err != nil { - return nil, err + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return dnsAccountID, nil } func generateTestData(size int) nbdns.Config { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 00e5d777a..eae70dae5 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,21 +7,12 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/stretchr/testify/require" ) type MockStore struct { Store - account *Account -} - -func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { - var peers []*nbpeer.Peer - for _, v := range s.account.Peers { - if v.Ephemeral { - peers = append(peers, v) - } - } - return peers, nil + accountID string } type MocAccountManager struct { @@ -29,9 +20,8 @@ type MocAccountManager struct { store *MockStore } -func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { - delete(a.store.account.Peers, peerID) - return nil //nolint:nil +func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error { + return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) } func TestNewManager(t *testing.T) { @@ -40,23 +30,26 @@ func TestNewManager(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) - if len(store.account.Peers) != numberOfPeers { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) - } + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers") } func TestNewManagerPeerConnected(t *testing.T) { @@ -65,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) - mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) + + peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0") + require.NoError(t, err, "failed to get peer") + + mgr.OnPeerConnected(context.Background(), peer) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) - expected := numberOfPeers + 1 - if len(store.account.Peers) != expected { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) - } + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers") } func TestNewManagerPeerDisconnected(t *testing.T) { @@ -93,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) - for _, v := range store.account.Peers { - mgr.OnPeerConnected(context.Background(), v) + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + for _, v := range peers { + mgr.OnPeerConnected(context.Background(), v) } - mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) + + peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0") + require.NoError(t, err, "failed to get peer") + mgr.OnPeerDisconnected(context.Background(), peer) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) + peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") expected := numberOfPeers + numberOfEphemeralPeers - 1 - if len(store.account.Peers) != expected { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) - } + require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers") } -func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { - store.account = newAccountWithId(context.Background(), "my account", "", "") +func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error { + accountID := "my account" + err := newAccountWithId(context.Background(), store, accountID, "", "") + if err != nil { + return err + } + store.accountID = accountID for i := 0; i < numberOfPeers; i++ { peerId := fmt.Sprintf("peer_%d", i) p := &nbpeer.Peer{ ID: peerId, + AccountID: accountID, Ephemeral: false, } - store.account.Peers[p.ID] = p + err = store.AddPeerToAccount(context.Background(), p) + if err != nil { + return err + } } for i := 0; i < numberOfEphemeralPeers; i++ { peerId := fmt.Sprintf("ephemeral_peer_%d", i) p := &nbpeer.Peer{ ID: peerId, + AccountID: accountID, Ephemeral: true, } - store.account.Peers[p.ID] = p + err = store.AddPeerToAccount(context.Background(), p) + if err != nil { + return err + } } + + return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 0515b9698..a4a85876d 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A } routeResource := &route.Route{ - ID: "example route", - Groups: []string{groupForRoute.ID}, + ID: "example route", + AccountID: accountID, + Groups: []string{groupForRoute.ID}, } routePeerGroupResource := &route.Route{ ID: "example route with peer groups", + AccountID: accountID, PeerGroups: []string{groupForRoute2.ID}, } nameServerGroup := &nbdns.NameServerGroup{ - ID: "example name server group", - Groups: []string{groupForNameServerGroups.ID}, + ID: "example name server group", + AccountID: accountID, + Groups: []string{groupForNameServerGroups.ID}, } policy := &Policy{ - ID: "example policy", + ID: "example policy", + AccountID: accountID, Rules: []*PolicyRule{ { ID: "example policy rule", + PolicyID: "example policy", Destinations: []string{groupForPolicies.ID}, }, }, @@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A setupKey := &SetupKey{ Id: "example setup key", + AccountID: accountID, AutoGroups: []string{groupForSetupKeys.ID}, } user := &User{ Id: "example user", + AccountID: accountID, AutoGroups: []string{groupForUsers.ID}, } - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) - account.Routes[routeResource.ID] = routeResource - account.Routes[routePeerGroupResource.ID] = routePeerGroupResource - account.NameServerGroups[nameServerGroup.ID] = nameServerGroup - account.Policies = append(account.Policies, policy) - account.SetupKeys[setupKey.Id] = setupKey - account.Users[user.Id] = user - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain) if err != nil { return nil, nil, err } - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource) + if err != nil { + return nil, nil, err + } - acc, err := am.Store.GetAccount(context.Background(), account.Id) + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup) + if err != nil { + return nil, nil, err + } + + err = am.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user) + if err != nil { + return nil, nil, err + } + + err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{ + groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies, + groupForSetupKeys, groupForUsers, groupForUsers, groupForIntegration, + }) + if err != nil { + return nil, nil, err + } + + acc, err := am.Store.GetAccount(context.Background(), accountID) if err != nil { return nil, nil, err } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 57ad968b3..dc8765e19 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -246,7 +246,7 @@ func Test_SyncProtocol(t *testing.T) { t.Fatal("expecting SyncResponse to have non-nil NetworkMap") } - if len(networkMap.GetRemotePeers()) != 4 { + if len(networkMap.GetRemotePeers()) != 3 { t.Fatalf("expecting SyncResponse to have NetworkMap with 3 remote peers, got %d", len(networkMap.GetRemotePeers())) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf023..6a305e723 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } outNSGroup, err := am.CreateNameServerGroup( context.Background(), - account.Id, + accountID, testCase.inputArgs.name, testCase.inputArgs.description, testCase.inputArgs.nameServers, @@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("account should be saved") - } + testCase.existingNSGroup.AccountID = accountID + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup) + require.NoError(t, err, "failed to save existing nameserver group") var nsGroupToSave *nbdns.NameServerGroup - if !testCase.skipCopying { nsGroupToSave = testCase.existingNSGroup.Copy() @@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) { } } - err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) - + err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave) testCase.errFunc(t, err) if !testCase.shouldCreate { return } - account, err = am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - } - - savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID] - require.True(t, saved) + savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID) + require.NoError(t, err, "failed to get saved nameserver group") + testCase.expectedNSGroup.AccountID = accountID if !testCase.expectedNSGroup.IsEqual(savedNSGroup) { t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup) } @@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.NameServerGroups[testingNSGroup.ID] = testingNSGroup + testingNSGroup.AccountID = accountID + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup) + require.NoError(t, err, "failed to save nameserver group") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save account") - } - - err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID) + err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID) if err != nil { t.Error("deleting nameserver group failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Error("failed to retrieve saved account with error: ", err) - } - - _, found := savedAccount.NameServerGroups[testingNSGroup.ID] - if found { - t.Error("nameserver group shouldn't be found after delete") - } + _, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID) + require.NotNil(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok, "error should be a status error") + assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete") } func TestGetNameServerGroup(t *testing.T) { @@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) + foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID) if err != nil { t.Error("getting existing nameserver group failed with error: ", err) } @@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("got a nil group while getting nameserver group with ID") } - _, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") + _, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing") if err == nil { t.Error("getting not existing nameserver group should return error, got nil") } @@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() + accountID := "testingAcc" + userID := testUserID + domain := "example.com" + peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, Name: "test-host1@netbird.io", @@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error } existingNSGroup := nbdns.NameServerGroup{ ID: existingNSGroupID, + AccountID: accountID, Name: existingNSGroupName, Description: "", NameServers: []nbdns.NameServer{ @@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error Enabled: true, } - accountID := "testingAcc" - userID := testUserID - domain := "example.com" - - account := newAccountWithId(context.Background(), accountID, userID, domain) - - account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - - newGroup1 := &nbgroup.Group{ - ID: group1ID, - Name: group1ID, - } - - newGroup2 := &nbgroup.Group{ - ID: group2ID, - Name: group2ID, - } - - account.Groups[newGroup1.ID] = newGroup1 - account.Groups[newGroup2.ID] = newGroup2 - - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { - return nil, err + return "", err + } + + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup) + if err != nil { + return "", err + } + + err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{ + { + ID: group1ID, + AccountID: accountID, + Name: group1ID, + }, + { + ID: group2ID, + AccountID: accountID, + Name: group2ID, + }, + }) + if err != nil { + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) if err != nil { - return nil, err + return "", err } - return account, nil + return accountID, nil } func TestValidateDomain(t *testing.T) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0e30a3762..fc63156c3 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -468,21 +468,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ - Id: someUser, - Role: UserRoleUser, - } - account.Settings.RegularUsersViewBlocked = false + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + require.NoError(t, err, "failed to create account") - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - return - } + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: someUser, + AccountID: accountID, + Role: UserRoleUser, + }) + require.NoError(t, err, "failed to create user") + + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = false + err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err, "failed to save account settings") // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -536,7 +540,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { assert.NotNil(t, peer) // delete the all-to-all policy so that user's peer1 has no access to peer2 - for _, policy := range account.Policies { + accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account policies") + + for _, policy := range accountPolicies { err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) if err != nil { t.Fatal(err) @@ -655,21 +662,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + require.NoError(t, err, "failed to create account") + + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: someUser, + AccountID: accountID, Role: testCase.role, IsServiceUser: testCase.isServiceUser, - } - account.Policies = []*Policy{} - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + }) + require.NoError(t, err, "failed to create user") - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - return + accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account policies") + + for _, policy := range accountPolicies { + err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) + require.NoError(t, err, "failed to delete policy") } + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = testCase.limitedViewSettings + err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err, "failed to save account settings") + peerKey1, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) @@ -725,10 +744,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou adminUser := "account_creator" regularUser := "regular_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ - Id: regularUser, - Role: UserRoleUser, + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + if err != nil { + return nil, "", "", err + } + + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: regularUser, + AccountID: accountID, + Role: UserRoleUser, + }) + if err != nil { + return nil, "", "", err } // Create peers @@ -742,31 +769,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Status: &nbpeer.PeerStatus{}, UserID: regularUser, } - account.Peers[peer.ID] = peer + err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer) + if err != nil { + return nil, "", "", err + } } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &nbgroup.Group{ - ID: groupID, - Name: fmt.Sprintf("Group %d", i), + ID: groupID, + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), } for j := 0; j < peers/groups; j++ { peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - account.Groups[groupID] = group + + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + return nil, "", "", err + } // Create a policy for this group policy := &Policy{ - ID: fmt.Sprintf("policy-%d", i), - Name: fmt.Sprintf("Policy for Group %d", i), - Enabled: true, + ID: fmt.Sprintf("policy-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, Rules: []*PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), + PolicyID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), Enabled: true, Sources: []string{groupID}, @@ -777,22 +813,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou }, }, } - account.Policies = append(account.Policies, policy) + + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + if err != nil { + return nil, "", "", err + } } - account.PostureChecks = []*posture.Checks{ - { - ID: "PostureChecksAll", - Name: "All", - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.0.1", - }, + err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{ + ID: "PostureChecksAll", + AccountID: accountID, + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", }, }, - } - - err = manager.Store.SaveAccount(context.Background(), account) + }) if err != nil { return nil, "", "", err } diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index e4b19da76..a135ad4af 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -41,6 +41,7 @@ type PersonalAccessToken struct { func (t *PersonalAccessToken) Copy() *PersonalAccessToken { return &PersonalAccessToken{ ID: t.ID, + UserID: t.UserID, Name: t.Name, HashedToken: t.HashedToken, ExpirationDate: t.ExpirationDate, diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 93e5741cf..aa99188eb 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -25,22 +25,22 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestPostureChecksAccount(am) + accountID, err := initTestPostureChecksAccount(am) if err != nil { t.Error("failed to init testing account") } t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check - _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) + _, err = am.ListPostureChecks(context.Background(), accountID, regularUserID) assert.Error(t, err) // should be possible to create posture check with uniq name - postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + postureCheck, err := am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -51,12 +51,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // admin users can list check - checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) + checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + _, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -76,45 +76,48 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { MinVersion: "0.27.0", }, } - _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) + _, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) + err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) + err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, adminUserID) assert.NoError(t, err) - checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) + checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 0) }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ - Id: adminUserID, - Role: UserRoleAdmin, - } - user := &User{ - Id: regularUserID, - Role: UserRoleUser, - } - - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) - account.Users[admin.Id] = admin - account.Users[user.Id] = user - - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain) if err != nil { - return nil, err + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: adminUserID, + AccountID: accountID, + Role: UserRoleAdmin, + }, + { + Id: regularUserID, + AccountID: accountID, + Role: UserRoleUser, + }, + }) + if err != nil { + return "", err + } + + return accountID, nil } func TestPostureCheckAccountPeersUpdate(t *testing.T) { @@ -440,18 +443,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "failed to create account manager") - account, err := initTestPostureChecksAccount(manager) + accountID, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") groupA := &group.Group{ ID: "groupA", - AccountID: account.Id, + AccountID: accountID, Peers: []string{"peer1"}, } groupB := &group.Group{ ID: "groupB", - AccountID: account.Id, + AccountID: accountID, Peers: []string{}, } err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) @@ -459,26 +462,26 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { postureCheckA := &posture.Checks{ Name: "checkA", - AccountID: account.Id, + AccountID: accountID, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, } - postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + postureCheckA, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA) require.NoError(t, err, "failed to save postureCheckA") postureCheckB := &posture.Checks{ Name: "checkB", - AccountID: account.Id, + AccountID: accountID, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, } - postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + postureCheckB, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB) require.NoError(t, err, "failed to save postureCheckB") policy := &Policy{ - AccountID: account.Id, + AccountID: accountID, Rules: []*PolicyRule{ { Enabled: true, @@ -489,23 +492,23 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + policy, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckB.ID) require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, "unknown") require.NoError(t, err) assert.False(t, result) }) @@ -513,10 +516,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to update policy") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.True(t, result) }) @@ -524,10 +527,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to update policy") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.True(t, result) }) @@ -537,7 +540,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) require.NoError(t, err, "failed to save groups") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.False(t, result) }) @@ -545,10 +548,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to update policy") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.False(t, result) }) diff --git a/management/server/route_test.go b/management/server/route_test.go index 108f791e0..41a8a03ae 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,9 +5,11 @@ import ( "fmt" "net" "net/netip" + "strings" "testing" "time" + "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -427,22 +429,22 @@ func TestCreateRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Errorf("failed to init testing account: %s", err) } if testCase.createInitRoute { - groupAll, errInit := account.GetGroupAll() + groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) + + _, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) - + outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) if !testCase.shouldCreate { @@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) { // assign generated ID testCase.expectedRoute.ID = outRoute.ID + testCase.expectedRoute.AccountID = accountID if !testCase.expectedRoute.IsEqual(outRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute) @@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } if testCase.createInitRoute { - account.Routes["initRoute"] = &route.Route{ + initRoute := &route.Route{ ID: "initRoute", + AccountID: accountID, Network: existingNetwork, NetID: existingRouteID, NetworkType: route.IPv4Network, @@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) { Enabled: true, Groups: []string{routeGroup1}, } + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute) + require.NoError(t, err, "failed to save init route") } - account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("account should be saved") - } + testCase.existingRoute.AccountID = accountID + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute) + require.NoError(t, err, "failed to save existing route") var routeToSave *route.Route @@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) { } } - err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) + err = am.SaveRoute(context.Background(), accountID, userID, routeToSave) testCase.errFunc(t, err) @@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) { return } - account, err = am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - } - - savedRoute, saved := account.Routes[testCase.expectedRoute.ID] - require.True(t, saved) + savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID) + require.NoError(t, err, "failed to get saved route") + testCase.expectedRoute.AccountID = accountID if !testCase.expectedRoute.IsEqual(savedRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) } @@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) { } func TestDeleteRoute(t *testing.T) { - testingRoute := &route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - Domains: domain.List{"domain1", "domain2"}, - KeepRoute: true, - NetworkType: route.IPv4Network, - Peer: peer1Key, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - } - am, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.Routes[testingRoute.ID] = testingRoute + err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ + ID: "GroupA", + AccountID: accountID, + Name: "GroupA", + }) + require.NoError(t, err, "failed to save group") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save account") + testingRoute := &route.Route{ + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: route.NetID("12345678901234567890qw"), + Groups: []string{"GroupA"}, + KeepRoute: true, + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, } + createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute) + require.NoError(t, err, "failed to create route") - err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID) if err != nil { t.Error("deleting route failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Error("failed to retrieve saved account with error: ", err) - } - - _, found := savedAccount.Routes[testingRoute.ID] - if found { - t.Error("route shouldn't be found after delete") - } + _, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID) + require.NotNil(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) } func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { @@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } + accountID, err := initTestRouteAccount(t, am) + require.NoError(t, err, "failed to init testing account") newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { @@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { } } - err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) + err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID) require.NoError(t, err) peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") - err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) + err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID) require.NoError(t, err) peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") - err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID) require.NoError(t, err) peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") - err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1158,7 +1153,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } @@ -1167,7 +1162,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1181,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { expectedRoute := enabledRoute.Copy() expectedRoute.Peer = peer1Key - err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) + err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute) require.NoError(t, err) peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1193,7 +1188,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") - err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) + err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID) require.NoError(t, err) peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) @@ -1206,10 +1201,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) + err = am.SaveGroup(context.Background(), accountID, userID, newGroup) require.NoError(t, err) - rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") + rules, err := am.ListPolicies(context.Background(), accountID, "testingUser") require.NoError(t, err) defaultRule := rules[0] @@ -1218,10 +1213,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + _, err = am.SavePolicy(context.Background(), accountID, userID, newPolicy) require.NoError(t, err) - err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) + err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID) require.NoError(t, err) peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1232,7 +1227,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") - err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1266,179 +1261,104 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() accountID := "testingAcc" domain := "example.com" - account := newAccountWithId(context.Background(), accountID, userID, domain) - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { - return nil, err + return "", err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) + createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) { + ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerIP, err := AllocatePeerIP(network.Net, ips) + if err != nil { + return nil, err + } + + peer := &nbpeer.Peer{ + IP: peerIP, + AccountID: accountID, + ID: peerID, + Key: peerKey, + Name: peerName, + DNSLabel: dnsLabel, + UserID: userID, + Meta: nbpeer.PeerSystemMeta{ + Hostname: peerName, + GoOS: strings.ToLower(kernel), + Kernel: kernel, + Core: core, + Platform: platform, + OS: os, + WtVersion: "development", + UIVersion: "development", + }, + Status: &nbpeer.PeerStatus{}, + } + if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil { + return nil, err + } + return peer, nil + } + + // Create peers + peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer1 := &nbpeer.Peer{ - IP: peer1IP, - ID: peer1ID, - Key: peer1Key, - Name: "test-host1@netbird.io", - DNSLabel: "test-host1", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host1@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer1.ID] = peer1 - - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) + peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer2 := &nbpeer.Peer{ - IP: peer2IP, - ID: peer2ID, - Key: peer2Key, - Name: "test-host2@netbird.io", - DNSLabel: "test-host2", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host2@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer2.ID] = peer2 - - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin") if err != nil { - return nil, err + return "", err } - - peer3 := &nbpeer.Peer{ - IP: peer3IP, - ID: peer3ID, - Key: peer3Key, - Name: "test-host3@netbird.io", - DNSLabel: "test-host3", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host3@netbird.io", - GoOS: "darwin", - Kernel: "Darwin", - Core: "13.4.1", - Platform: "arm64", - OS: "darwin", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer3.ID] = peer3 - - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer4 := &nbpeer.Peer{ - IP: peer4IP, - ID: peer4ID, - Key: peer4Key, - Name: "test-host4@netbird.io", - DNSLabel: "test-host4", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host4@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer4.ID] = peer4 - - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - peer5 := &nbpeer.Peer{ - IP: peer5IP, - ID: peer5ID, - Key: peer5Key, - Name: "test-host5@netbird.io", - DNSLabel: "test-host5", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host5@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, + groupAll, err := am.GetGroupByName(context.Background(), "All", accountID) + if err != nil { + return "", err } - account.Peers[peer5.ID] = peer5 - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - return nil, err - } - groupAll, err := account.GetGroupAll() - if err != nil { - return nil, err - } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) if err != nil { - return nil, err + return "", err } - newGroup := []*nbgroup.Group{ + newGroups := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1470,15 +1390,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er Peers: []string{peer1.ID, peer4.ID}, }, } - - for _, group := range newGroup { - err = am.SaveGroup(context.Background(), accountID, userID, group) - if err != nil { - return nil, err - } + err = am.SaveGroups(context.Background(), accountID, userID, newGroups) + if err != nil { + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return accountID, nil } func TestAccount_getPeersRoutesFirewall(t *testing.T) { @@ -1782,10 +1699,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") - account, err := initTestRouteAccount(t, manager) + accountID, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{ { ID: "groupA", Name: "GroupA", @@ -1831,7 +1748,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() _, err := manager.CreateRoute( - context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.Groups, []string{}, true, userID, route.KeepRoute, ) @@ -1867,7 +1784,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() _, err := manager.CreateRoute( - context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.Groups, []string{}, true, userID, route.KeepRoute, ) @@ -1903,7 +1820,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() newRoute, err := manager.CreateRoute( - context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, + context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, ) @@ -1927,7 +1844,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute) require.NoError(t, err) select { @@ -1945,7 +1862,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) + err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID) require.NoError(t, err) select { @@ -1969,7 +1886,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Groups: []string{routeGroup1}, } _, err := manager.CreateRoute( - context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) @@ -1981,7 +1898,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2009,7 +1926,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Groups: []string{"groupC"}, } _, err := manager.CreateRoute( - context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) @@ -2021,7 +1938,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index ea239ec0c..4ef765a51 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -25,12 +25,12 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, + key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{}, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) @@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups := []string{"group_1", "group_2"} newKeyName := "my-new-test-key" revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated - ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) + ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked) assert.NotNil(t, ev) - assert.Equal(t, account.Id, ev.AccountID) + assert.Equal(t, accountID, ev.AccountID) assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, key.Id, ev.TargetID) - groupAll, err := account.GetGroupAll() + groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID) assert.NoError(t, err) // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -103,12 +103,12 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -117,7 +117,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -126,7 +126,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - groupAll, err := account.GetGroupAll() + groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID) assert.NoError(t, err) type testCase struct { @@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, + key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn, tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { @@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated - ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) + ev := getEvent(t, accountID, manager, activity.SetupKeyCreated) assert.NotNil(t, ev) - assert.Equal(t, account.Id, ev.AccountID) + assert.Equal(t, accountID, ev.AccountID) assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"]) assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) @@ -208,12 +208,12 @@ func TestGetSetupKeys(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -222,7 +222,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6d531ba2e..25eec7c50 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -261,7 +261,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { return result.Error } - result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + result = tx.Debug().Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) if result.Error != nil { return result.Error } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 36c6eac32..88d5cd5e0 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -68,20 +68,27 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { func runLargeTest(t *testing.T, store Store) { t.Helper() - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - groupALL, err := account.GetGroupAll() - if err != nil { - t.Fatal(err) - } + accountID := "account_id" + + err := newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") + + groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") + require.NoError(t, err, "failed to get group All") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + assert.NoError(t, err, "failed to save setup key") + const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { netIP := randomIPv4() - peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + peerID := fmt.Sprintf("%s-peer-%d", accountID, n) peer := &nbpeer.Peer{ ID: peerID, + AccountID: accountID, Key: peerID, IP: netIP, Name: peerID, @@ -90,16 +97,21 @@ func runLargeTest(t *testing.T, store Store) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } - account.Peers[peerID] = peer - group, _ := account.GetGroupAll() - group.Peers = append(group.Peers, peerID) - user := &User{ - Id: fmt.Sprintf("%s-user-%d", account.Id, n), - AccountID: account.Id, - } - account.Users[user.Id] = user + err = store.AddPeerToAccount(context.Background(), peer) + require.NoError(t, err, "failed to add peer") + + err = store.AddPeerToAllGroup(context.Background(), accountID, peerID) + require.NoError(t, err, "failed to add peer to all group") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: fmt.Sprintf("%s-user-%d", accountID, n), + AccountID: accountID, + }) + require.NoError(t, err, "failed to save user") + route := &route2.Route{ ID: route2.ID(fmt.Sprintf("network-id-%d", n)), + AccountID: accountID, Description: "base route", NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), Network: netip.MustParsePrefix(netIP.String() + "/24"), @@ -107,22 +119,24 @@ func runLargeTest(t *testing.T, store Store) { Metric: 9999, Masquerade: false, Enabled: true, - Groups: []string{groupALL.ID}, + Groups: []string{groupAll.ID}, } - account.Routes[route.ID] = route + err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + require.NoError(t, err, "failed to save route") - group = &nbgroup.Group{ + group := &nbgroup.Group{ ID: fmt.Sprintf("group-id-%d", n), - AccountID: account.Id, + AccountID: accountID, Name: fmt.Sprintf("group-id-%d", n), Issued: "api", Peers: nil, } - account.Groups[group.ID] = group + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err, "failed to save group") nameserver := &nbdns.NameServerGroup{ ID: fmt.Sprintf("nameserver-id-%d", n), - AccountID: account.Id, + AccountID: accountID, Name: fmt.Sprintf("nameserver-id-%d", n), Description: "", NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, @@ -132,20 +146,20 @@ func runLargeTest(t *testing.T, store Store) { Enabled: false, SearchDomainsEnabled: false, } - account.NameServerGroups[nameserver.ID] = nameserver + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver) + require.NoError(t, err, "failed to save nameserver group") - setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey + setupKey, _ = GenerateDefaultSetupKey() + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") } - err = store.SaveAccount(context.Background(), account) + totalAccounts, err := store.GetTotalAccounts(context.Background()) require.NoError(t, err) + require.Equal(t, int64(1), totalAccounts, "expected 1 account") - if len(store.GetAllAccounts(context.Background())) != 1 { - t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") - } - - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -213,41 +227,53 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) - assert.NoError(t, err) + require.NoError(t, err) + + accountID := "account_id" + err = newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") - account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + AccountID: accountID, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") + + accountID2 := "account_id2" + err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "") + require.NoError(t, err, "failed to create account") - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey, _ = GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + setupKey.AccountID = accountID2 + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer2", + Key: "peerkey2", + AccountID: accountID2, + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -295,7 +321,12 @@ func TestSqlite_DeleteAccount(t *testing.T) { Name: "test token", }} - account := newAccountWithId(context.Background(), "account_id", testUserID, "") + err = newAccountWithId(context.Background(), store, "account_id", testUserID, "") + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), "account_id") + require.NoError(t, err, "failed to get account") + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -685,19 +716,29 @@ func newSqliteStore(t *testing.T) *SqlStore { } func newAccount(store Store, id int) error { - str := fmt.Sprintf("%s-%d", uuid.New().String(), id) - account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") + accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id) + userID := accountID + "-testuser" + + err := newAccountWithId(context.Background(), store, accountID, userID, "example.com") + if err != nil { + return err + } + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["p"+str] = &nbpeer.Peer{ - Key: "peerkey" + str, + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + if err != nil { + return err + } + + return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ + ID: "p" + accountID, + Key: accountID + "-peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - return store.SaveAccount(context.Background(), account) + }) } func TestPostgresql_NewStore(t *testing.T) { @@ -725,39 +766,53 @@ func TestPostgresql_SaveAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + accountID := "account_id" + + err = newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + AccountID: accountID, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") + + accountID2 := "account_id2" + + err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "") + require.NoError(t, err, "failed to create account") - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey, _ = GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + setupKey.AccountID = accountID2 + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer2", + Key: "peerkey2", + AccountID: accountID2, + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") - if len(store.GetAllAccounts(context.Background())) != 2 { - t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") - } + totalAccounts, err := store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(2), totalAccounts, "expecting 2 Accounts to be stored after SaveAccount()") - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -798,31 +853,41 @@ func TestPostgresql_DeleteAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) + accountID := "account_id" testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { - ID: "testtoken", - Name: "test token", - }} - account := newAccountWithId(context.Background(), "account_id", testUserID, "") + err = newAccountWithId(context.Background(), store, accountID, testUserID, "") + require.NoError(t, err, "failed to create account") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Users[testUserID] = user + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testingpeer", + AccountID: accountID, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") - if len(store.GetAllAccounts(context.Background())) != 1 { - t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") - } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: "testtoken", + UserID: testUserID, + Name: "test token", + }) + require.NoError(t, err, "failed to save personal access token") + + totalAccounts, err := store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(1), totalAccounts, "expecting 1 Accounts to be stored after SaveAccount()") + + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) @@ -1172,7 +1237,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "example.com" category := "public" IsDomainPrimaryAccount := false - err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount) require.NoError(t, err) account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) @@ -1186,7 +1251,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "test.com" category := "private" IsDomainPrimaryAccount := true - err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount) require.NoError(t, err) account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) @@ -1200,7 +1265,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "test.com" category := "private" IsDomainPrimaryAccount := true - err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id", domain, category, &IsDomainPrimaryAccount) require.Error(t, err) }) @@ -2668,7 +2733,7 @@ func TestSqlStore_GetTotalAccounts(t *testing.T) { t.Cleanup(cleanup) require.NoError(t, err) - totalAccounts, err := store.GetTotalAccounts(context.Background(), LockingStrengthShare) + totalAccounts, err := store.GetTotalAccounts(context.Background()) require.NoError(t, err) require.Equal(t, int64(1), totalAccounts) } diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 54b946b5a..64e47ff69 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -33,4 +33,7 @@ INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user_test.go b/management/server/user_test.go index 2f8c1bf70..cd43aab31 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -43,37 +43,34 @@ const ( func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) + newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Equal(t, pat.CreatedBy, mockUserID) + assert.Equal(t, newPAT.CreatedBy, mockUserID) - tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) + pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken) if err != nil { t.Fatalf("Error when getting token ID by hashed token: %s", err) } - if tokenID == "" { + if pat.ID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") } - assert.Equal(t, pat.ID, tokenID) + assert.Equal(t, newPAT.ID, pat.ID) - user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, tokenID) + user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID) if err != nil { t.Fatalf("Error when getting user by token ID: %s", err) } @@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockTargetUserId, + AccountID: mockAccountID, IsServiceUser: false, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockTargetUserId, + AccountID: mockAccountID, IsServiceUser: true, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithEmptyName(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) - assert.Errorf(t, err, "Wrong expiration should thorw error") + assert.Errorf(t, err, "Wrong expiration should throw error") } func TestUser_DeletePAT(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) { t.Fatalf("Error when adding PAT to user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when getting account: %s", err) } @@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - mockTokenID2: { - ID: mockTokenID2, - HashedToken: mockToken2, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID2, + UserID: mockUserID, + HashedToken: mockToken2, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) { func TestUser_CreateServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) { t.Fatalf("Error when creating service user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) assert.Equal(t, 2, len(account.Users)) @@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { t.Fatalf("Error when creating user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) assert.True(t, user.IsServiceUser) @@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -549,13 +519,13 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + tt.serviceUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser) + assert.NoError(t, err, "failed to create service user") am := DefaultAccountManager{ Store: store, @@ -582,12 +552,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -603,39 +570,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - targetId := "user2" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: true, - ServiceUserName: "user2username", - } - targetId = "user3" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - } - targetId = "user4" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedIntegration, - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") - targetId = "user5" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: "user2", + AccountID: mockAccountID, + IsServiceUser: true, + ServiceUserName: "user2username", + }, + { + Id: "user3", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user4", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedIntegration, + }, + { + Id: "user5", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + }, + }) + assert.NoError(t, err, "failed to save users") am := DefaultAccountManager{ Store: store, @@ -685,61 +651,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_RegularUsers(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - targetId := "user2" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: true, - ServiceUserName: "user2username", - } - targetId = "user3" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - } - targetId = "user4" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedIntegration, - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") - targetId = "user5" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, - } - account.Users["user6"] = &User{ - Id: "user6", - IsServiceUser: false, - Issued: UserIssuedAPI, - } - account.Users["user7"] = &User{ - Id: "user7", - IsServiceUser: false, - Issued: UserIssuedAPI, - } - account.Users["user8"] = &User{ - Id: "user8", - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, - } - account.Users["user9"] = &User{ - Id: "user9", - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: "user2", + AccountID: mockAccountID, + IsServiceUser: true, + ServiceUserName: "user2username", + }, + { + Id: "user3", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user4", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedIntegration, + }, + { + Id: "user5", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + }, + { + Id: "user6", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user7", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user8", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + }, + { + Id: "user9", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + }, + }) + assert.NoError(t, err) am := DefaultAccountManager{ Store: store, @@ -816,7 +785,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.Store.GetAccount(context.Background(), account.Id) + acc, err := am.Store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) for _, id := range tc.expectedDeleted { @@ -836,12 +805,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -865,14 +831,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + newUser := NewRegularUser("normal_user1") + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") + + newUser = NewRegularUser("normal_user2") + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -946,15 +917,25 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") + + settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID) + assert.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = testCase.limitedViewSettings + + err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings) + assert.NoError(t, err, "failed to save account settings") + + err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID) + assert.NoError(t, err, "failed to delete user") am := DefaultAccountManager{ Store: store, @@ -968,7 +949,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { assert.Equal(t, 1, len(users)) - userInfo, _ := users[0].ToUserInfo(nil, account.Settings) + userInfo, _ := users[0].ToUserInfo(nil, settings) assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) }) } @@ -978,22 +959,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ - Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: "externalUser", + AccountID: mockAccountID, + Role: UserRoleUser, + Issued: UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", }, - } - account.Users[externalUser.Id] = externalUser - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1013,6 +993,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { assert.NoError(t, err) cacheManager := am.GetExternalCacheManager() + + externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser") + assert.NoError(t, err, "failed to get user") + cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) assert.NoError(t, err) @@ -1042,17 +1026,17 @@ func TestUser_IsAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockServiceUserID, + AccountID: mockAccountID, Role: "user", IsServiceUser: true, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1071,17 +1055,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockServiceUserID, + AccountID: mockAccountID, Role: "user", IsServiceUser: true, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1240,21 +1223,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io") if err != nil { t.Fatal(err) } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) + regularUser := NewRegularUser(regularUserID) + regularUser.AccountID = accountID + + adminUser := NewAdminUser(adminUserID) + adminUser.AccountID = accountID + + serviceUser := &User{ + Id: serviceUserID, + AccountID: accountID, + IsServiceUser: true, + Role: UserRoleAdmin, + ServiceUserName: "service", } - updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) + err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser}) + assert.NoError(t, err, "failed to save users") + + updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update) if tc.expectedErr { require.Errorf(t, err, "expecting SaveUser to throw an error") } else {