fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-06 15:59:12 +03:00
parent e513e51e9f
commit 9cb7336ef5
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
18 changed files with 1244 additions and 1132 deletions

View File

@ -28,9 +28,12 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');
COMMIT; COMMIT;

View File

@ -817,15 +817,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList return groupList
} }
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps
}
func (a *Account) getPeerDNSLabels() lookupMap { func (a *Account) getPeerDNSLabels() lookupMap {
existingLabels := make(lookupMap) existingLabels := make(lookupMap)
for _, peer := range a.Peers { for _, peer := range a.Peers {
@ -1147,8 +1138,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.updateAccountPeers(ctx, accountID)
return newSettings, nil return newSettings, nil
} }

View File

@ -398,7 +398,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
for _, testCase := range tt { for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") store := newStore(t)
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), "account-1")
require.NoError(t, err, "failed to get account")
account.UpdateSettings(&testCase.accountSettings) account.UpdateSettings(&testCase.accountSettings)
account.Network = network account.Network = network
account.Peers = testCase.peers account.Peers = testCase.peers
@ -416,6 +423,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
store.Close(context.Background())
} }
} }
@ -423,7 +432,15 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io" domain := "netbird.io"
userId := "account_creator" userId := "account_creator"
accountID := "account_id" accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
store := newStore(t)
defer store.Close(context.Background())
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
} }
@ -434,16 +451,16 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userID) t.Fatalf("expected to create an account for a user %s", userID)
return return
} }
account, err = manager.Store.GetAccountByUser(context.Background(), userID) account, err := manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return return
@ -666,15 +683,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID) initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed") require.NoError(t, err, "get init account failed")
@ -690,44 +704,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, accountGroups, 1, "only ALL group should exists")
}) })
t.Run("JWT groups enabled without claim name", func(t *testing.T) { t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount) _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "save account failed") require.NoError(t, err, "failed to update account settings")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get account ids")
require.Len(t, accountIDs, 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
}) })
t.Run("JWT groups enabled", func(t *testing.T) { t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups" initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount) _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "save account failed") require.NoError(t, err, "failed to update account settings")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get account ids")
require.Len(t, accountIDs, 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID) exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "get account failed") require.NoError(t, err, "failed to check account existence")
require.True(t, exists, "account should exist")
require.Len(t, account.Groups, 3, "groups should be added to the account") accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
require.NoError(t, err, "failed to get account groups")
require.Len(t, accountGroups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{} groupsByNames := map[string]*group.Group{}
for _, g := range account.Groups { for _, g := range accountGroups {
groupsByNames[g.Name] = g groupsByNames[g.Name] = g
} }
@ -745,60 +768,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser", userPAT := &PersonalAccessToken{
PATs: map[string]*PersonalAccessToken{ ID: "tokenId",
"tokenId": { UserID: "testuser",
ID: "tokenId", HashedToken: encodedHashedToken,
HashedToken: encodedHashedToken, CreatedAt: time.Now().UTC(),
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
} }
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
if err != nil { if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err) t.Fatalf("Error when getting Account from PAT: %s", err)
} }
assert.Equal(t, "account_id", account.Id) assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id) assert.Equal(t, "testuser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) assert.Equal(t, userPAT, pat)
} }
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser", userPAT := &PersonalAccessToken{
PATs: map[string]*PersonalAccessToken{ ID: "tokenId",
"tokenId": { UserID: "someUser",
ID: "tokenId", HashedToken: encodedHashedToken,
HashedToken: encodedHashedToken, LastUsed: time.Time{},
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -809,11 +825,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when marking PAT used: %s", err) t.Fatalf("Error when marking PAT used: %s", err)
} }
account, err = am.Store.GetAccount(context.Background(), "account_id") userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
if err != nil { require.NoError(t, err, "failed to get PAT")
t.Fatalf("Error when getting account: %s", err)
} assert.True(t, !userPAT.LastUsed.IsZero())
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
} }
func TestAccountManager_PrivateAccount(t *testing.T) { func TestAccountManager_PrivateAccount(t *testing.T) {
@ -824,15 +839,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
} }
userId := "test_user" userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
account, err = manager.Store.GetAccountByUser(context.Background(), userId) account, err := manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
} }
@ -851,32 +866,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user" userId := "test_user"
domain := "hotmail.com" domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatal(err) require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
}
if account == nil {
t.Fatalf("expected to create an account for a user %s", userId)
}
if account != nil && account.Domain != domain { accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) require.NoError(t, err, "failed to get account domain and category")
} require.Equal(t, domain, accDomain, "expected account domain to match")
domain = "gmail.com" domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
if account == nil { accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
t.Fatalf("expected to get an account for a user %s", userId) require.NoError(t, err, "failed to get account domain and category")
} require.Equal(t, domain, accDomain, "expected account domain to match")
if account != nil && account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
} }
func TestAccountManager_GetAccountByUserID(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
@ -908,12 +913,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
} }
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain) err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return account, nil return am.Store.GetAccount(context.Background(), accountID)
} }
func TestAccountManager_GetAccount(t *testing.T) { func TestAccountManager_GetAccount(t *testing.T) {
@ -1056,23 +1060,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return return
} }
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
if err != nil { require.NoError(t, err, "failed to get or create account by user")
t.Fatal(err)
}
serial := account.Network.CurrentSerial() // should be 0 network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account network")
if account.Network.Serial != 0 { serial := network.CurrentSerial() // should be 0
t.Errorf("expecting account network to have an initial Serial=0") require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
return
}
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { require.NoError(t, err, "failed to generate private key")
t.Fatal(err)
return
}
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedUserID := userID expectedUserID := userID
@ -1080,16 +1079,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
if err != nil { require.NoError(t, err, "failed to add peer")
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
return
}
account, err = manager.Store.GetAccount(context.Background(), account.Id) account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil { require.NoError(t, err, "failed to get account")
t.Fatal(err)
return
}
if peer.Key != expectedPeerKey { if peer.Key != expectedPeerKey {
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
@ -1215,10 +1208,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policyID := xid.New().String()
policy := Policy{ policy := Policy{
Enabled: true, ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: policyID,
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@ -1252,19 +1250,25 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t) manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := group.Group{ group := group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
} }
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err) t.Errorf("save group: %v", err)
return return
} }
policyID := xid.New().String()
policy := Policy{ policy := Policy{
Enabled: true, ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: policyID,
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@ -1305,19 +1309,24 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{ group := group.Group{
ID: "groupA", ID: "groupA",
Name: "GroupA", AccountID: account.Id,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
} }
err := manager.SaveGroup(context.Background(), account.Id, userID, &group)
require.NoError(t, err, "failed to save group")
policyID := xid.New().String()
policy := Policy{ policy := Policy{
Enabled: true, ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "rule",
PolicyID: policyID,
Enabled: true, Enabled: true,
Sources: []string{"groupA"}, Sources: []string{"groupA"},
Destinations: []string{"groupA"}, Destinations: []string{"groupA"},
@ -1327,6 +1336,9 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
}, },
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
@ -1355,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return return
} }
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { if err := manager.DeleteGroup(context.Background(), account.Id, userID, group.ID); err != nil {
t.Errorf("delete group: %v", err) t.Errorf("delete group: %v", err)
return return
} }
@ -1754,12 +1766,6 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
@ -1865,10 +1871,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}, },
} }
// enabling PeerLoginExpirationEnabled should trigger the expiration job // enabling PeerLoginExpirationEnabled should trigger the expiration job
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
PeerLoginExpiration: time.Hour, require.NoError(t, err, "failed to get account settings")
PeerLoginExpirationEnabled: true,
}) settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
settings, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@ -1878,10 +1886,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1) wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel // disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Hour, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second) failed = waitTimeout(wg, time.Second)
if failed { if failed {
@ -1896,30 +1902,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
PeerLoginExpiration: time.Hour, require.NoError(t, err, "unable to get account settings")
PeerLoginExpirationEnabled: false,
}) settings.PeerLoginExpirationEnabled = false
settings.PeerLoginExpiration = time.Hour
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour) assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Second, settings.PeerLoginExpiration = time.Second
PeerLoginExpirationEnabled: false, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ settings.PeerLoginExpirationEnabled = false
PeerLoginExpiration: time.Hour * 24 * 181, settings.PeerLoginExpiration = time.Hour * 24 * 181
PeerLoginExpirationEnabled: false, _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }

View File

@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Fatal("failed to init testing account") t.Fatal("failed to init testing account")
} }
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil { if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("DNS settings for new accounts shouldn't return nil") t.Fatal("DNS settings for new accounts shouldn't return nil")
} }
account.DNSSettings = DNSSettings{ err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{
DisabledManagementGroups: []string{group1ID}, DisabledManagementGroups: []string{group1ID},
} })
require.NoError(t, err, "failed to update DNS settings")
err = am.Store.SaveAccount(context.Background(), account) dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil { if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
} }
@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
} }
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) _, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID)
if err == nil { if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user") t.Errorf("An error should be returned when getting the DNS settings with a regular user")
} }
@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings)
if err != nil { if err != nil {
if testCase.shouldFail { if testCase.shouldFail {
return return
@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err) t.Error(err)
} }
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) updatedAccount, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil { if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err) t.Errorf("should be able to retrieve updated account, got err: %s", err)
} }
@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestDNSAccount(t, am) accountID, err := initTestDNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
peer1, err := account.FindPeerByPubKey(dnsPeer1Key) peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
peer2, err := account.FindPeerByPubKey(dnsPeer2Key) peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy() accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account DNS settings")
dnsSettings := accountDNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to update DNS settings")
require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err) require.NoError(t, err)
@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
Key: dnsPeer1Key, Key: dnsPeer1Key,
@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com" domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain)
if err != nil {
account.Users[dnsRegularUserID] = &User{ return "", err
Id: dnsRegularUserID,
Role: UserRoleUser,
} }
err := am.Store.SaveAccount(context.Background(), account) err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: dnsRegularUserID,
AccountID: dnsAccountID,
Role: UserRoleUser,
})
if err != nil { if err != nil {
return nil, err return "", err
} }
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil { if err != nil {
return nil, err return "", err
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer1, err = account.FindPeerByPubKey(peer1.Key) _, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, err = account.FindPeerByPubKey(peer2.Key) err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{
{
ID: dnsGroup1ID,
AccountID: dnsAccountID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
},
{
ID: dnsGroup2ID,
AccountID: dnsAccountID,
Name: dnsGroup2ID,
},
})
if err != nil { if err != nil {
return nil, err return "", err
} }
newGroup1 := &group.Group{ allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All")
ID: dnsGroup1ID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
}
newGroup2 := &group.Group{
ID: dnsGroup2ID,
Name: dnsGroup2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
return nil, err return "", err
} }
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{
ID: dnsNSGroup1, ID: dnsNSGroup1,
Name: "ns-group-1", AccountID: dnsAccountID,
Name: "ns-group-1",
NameServers: []dns.NameServer{{ NameServers: []dns.NameServer{{
IP: netip.MustParseAddr(savedPeer1.IP.String()), IP: netip.MustParseAddr(savedPeer1.IP.String()),
NSType: dns.UDPNameServerType, NSType: dns.UDPNameServerType,
@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Primary: true, Primary: true,
Enabled: true, Enabled: true,
Groups: []string{allGroup.ID}, Groups: []string{allGroup.ID},
} })
err = am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
return am.Store.GetAccount(context.Background(), account.Id) return dnsAccountID, nil
} }
func generateTestData(size int) nbdns.Config { func generateTestData(size int) nbdns.Config {

View File

@ -7,35 +7,35 @@ import (
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/require"
) )
type MockStore struct { type MockStore struct {
Store Store
account *Account accountID string
} }
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { //func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account} // return []*Account{s.account}
} //}
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { //func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId] //
if ok { // _, ok := s.account.Peers[peerId]
return s.account, nil // if ok {
} // return s.account, nil
// }
return nil, status.NewPeerNotFoundError(peerId) //
} // return nil, status.NewPeerNotFoundError(peerId)
//}
type MocAccountManager struct { type MocAccountManager struct {
AccountManager AccountManager
store *MockStore store *MockStore
} }
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error {
delete(a.store.account.Peers, peerID) return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
return nil //nolint:nil
} }
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {
@ -44,23 +44,26 @@ func TestNewManager(t *testing.T) {
return startTime return startTime
} }
store := &MockStore{} store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{ am := MocAccountManager{
store: store, store: store,
} }
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers { peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) require.NoError(t, err, "failed to get account peers")
} require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers")
} }
func TestNewManagerPeerConnected(t *testing.T) { func TestNewManagerPeerConnected(t *testing.T) {
@ -76,19 +79,23 @@ func TestNewManagerPeerConnected(t *testing.T) {
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerConnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
expected := numberOfPeers + 1 peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
if len(store.account.Peers) != expected { require.NoError(t, err, "failed to get account peers")
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers")
}
} }
func TestNewManagerPeerDisconnected(t *testing.T) { func TestNewManagerPeerDisconnected(t *testing.T) {
@ -104,43 +111,64 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
numberOfPeers := 5 numberOfPeers := 5
numberOfEphemeralPeers := 3 numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers) err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am) mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background()) mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
for _, v := range peers {
mgr.OnPeerConnected(context.Background(), v)
} }
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerDisconnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1) startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background()) mgr.cleanup(context.Background())
peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
expected := numberOfPeers + numberOfEphemeralPeers - 1 expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected { require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers")
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
} }
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error {
store.account = newAccountWithId(context.Background(), "my account", "", "") accountID := "my account"
err := newAccountWithId(context.Background(), store, accountID, "", "")
if err != nil {
return err
}
store.accountID = accountID
for i := 0; i < numberOfPeers; i++ { for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i) peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
AccountID: accountID,
Ephemeral: false, Ephemeral: false,
} }
store.account.Peers[p.ID] = p err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, p)
if err != nil {
return err
}
} }
for i := 0; i < numberOfEphemeralPeers; i++ { for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i) peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
AccountID: accountID,
Ephemeral: true, Ephemeral: true,
} }
store.account.Peers[p.ID] = p err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, p)
if err != nil {
return err
}
} }
return nil
} }

View File

@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
} }
routeResource := &route.Route{ routeResource := &route.Route{
ID: "example route", ID: "example route",
Groups: []string{groupForRoute.ID}, AccountID: accountID,
Groups: []string{groupForRoute.ID},
} }
routePeerGroupResource := &route.Route{ routePeerGroupResource := &route.Route{
ID: "example route with peer groups", ID: "example route with peer groups",
AccountID: accountID,
PeerGroups: []string{groupForRoute2.ID}, PeerGroups: []string{groupForRoute2.ID},
} }
nameServerGroup := &nbdns.NameServerGroup{ nameServerGroup := &nbdns.NameServerGroup{
ID: "example name server group", ID: "example name server group",
Groups: []string{groupForNameServerGroups.ID}, AccountID: accountID,
Groups: []string{groupForNameServerGroups.ID},
} }
policy := &Policy{ policy := &Policy{
ID: "example policy", ID: "example policy",
AccountID: accountID,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: "example policy rule", ID: "example policy rule",
PolicyID: "example policy",
Destinations: []string{groupForPolicies.ID}, Destinations: []string{groupForPolicies.ID},
}, },
}, },
@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
setupKey := &SetupKey{ setupKey := &SetupKey{
Id: "example setup key", Id: "example setup key",
AccountID: accountID,
AutoGroups: []string{groupForSetupKeys.ID}, AutoGroups: []string{groupForSetupKeys.ID},
} }
user := &User{ user := &User{
Id: "example user", Id: "example user",
AccountID: accountID,
AutoGroups: []string{groupForUsers.ID}, AutoGroups: []string{groupForUsers.ID},
} }
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) if err != nil {
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) return nil, nil, err
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) }
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
acc, err := am.Store.GetAccount(context.Background(), account.Id) err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup)
if err != nil {
return nil, nil, err
}
err = am.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user)
if err != nil {
return nil, nil, err
}
err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{
groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies,
groupForSetupKeys, groupForUsers, groupForUsers, groupForIntegration,
})
if err != nil {
return nil, nil, err
}
acc, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -68,7 +68,7 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return nil, fmt.Errorf("unknown group name") return nil, fmt.Errorf("unknown group name")
}, },
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { GetUserPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil return maps.Values(TestPeers), nil
}, },
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {

View File

@ -33,7 +33,8 @@ var testAccount = &server.Account{
Domain: domain, Domain: domain,
Users: map[string]*server.User{ Users: map[string]*server.User{
userID: { userID: {
Id: userID, Id: userID,
AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{ PATs: map[string]*server.PersonalAccessToken{
tokenID: { tokenID: {
ID: tokenID, ID: tokenID,
@ -49,11 +50,11 @@ var testAccount = &server.Account{
}, },
} }
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
if token == PAT { if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
} }
return nil, nil, nil, fmt.Errorf("PAT invalid") return nil, nil, "", "", fmt.Errorf("PAT invalid")
} }
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
) )
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT, mockGetAccountInfoFromPAT,
mockValidateAndParseToken, mockValidateAndParseToken,
mockMarkPATUsed, mockMarkPATUsed,
mockCheckUserAccessByJWTGroups, mockCheckUserAccessByJWTGroups,

View File

@ -39,6 +39,68 @@ const (
) )
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: "test_id",
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: "test_id",
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: "test_id",
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return &PeersHandler{ return &PeersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@ -64,77 +126,37 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
} }
return p, nil return p, nil
}, },
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { GetUserPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil return peers, nil
}, },
ListPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
peersID := make([]string, len(peers))
for _, peer := range peers {
peersID = append(peersID, peer.ID)
}
return []*nbgroup.Group{
{
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: peersID,
},
}, nil
},
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func() string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
}, },
GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) {
return account, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return account, nil return account, nil
}, },
HasConnectedChannelFunc: func(peerID string) bool { HasConnectedChannelFunc: func(peerID string) bool {

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
outNSGroup, err := am.CreateNameServerGroup( outNSGroup, err := am.CreateNameServerGroup(
context.Background(), context.Background(),
account.Id, accountID,
testCase.inputArgs.name, testCase.inputArgs.name,
testCase.inputArgs.description, testCase.inputArgs.description,
testCase.inputArgs.nameServers, testCase.inputArgs.nameServers,
@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup testCase.existingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup)
err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save existing nameserver group")
if err != nil {
t.Error("account should be saved")
}
var nsGroupToSave *nbdns.NameServerGroup var nsGroupToSave *nbdns.NameServerGroup
if !testCase.skipCopying { if !testCase.skipCopying {
nsGroupToSave = testCase.existingNSGroup.Copy() nsGroupToSave = testCase.existingNSGroup.Copy()
@ -651,7 +648,7 @@ func TestSaveNameServerGroup(t *testing.T) {
} }
} }
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave)
testCase.errFunc(t, err) testCase.errFunc(t, err)
@ -659,13 +656,8 @@ func TestSaveNameServerGroup(t *testing.T) {
return return
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID)
if err != nil { require.NoError(t, err, "failed to get saved nameserver group")
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved)
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) { if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup) t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup testingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup)
require.NoError(t, err, "failed to save nameserver group")
err = am.Store.SaveAccount(context.Background(), account) err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
if err != nil { if err != nil {
t.Error("deleting nameserver group failed with error: ", err) t.Error("deleting nameserver group failed with error: ", err)
} }
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) _, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID)
if err != nil { require.NotNil(t, err)
t.Error("failed to retrieve saved account with error: ", err) sErr, ok := status.FromError(err)
} require.True(t, ok, "error should be a status error")
assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete")
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
if found {
t.Error("nameserver group shouldn't be found after delete")
}
} }
func TestGetNameServerGroup(t *testing.T) { func TestGetNameServerGroup(t *testing.T) {
@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestNSAccount(t, am) accountID, err := initTestNSAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID)
if err != nil { if err != nil {
t.Error("getting existing nameserver group failed with error: ", err) t.Error("getting existing nameserver group failed with error: ", err)
} }
@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("got a nil group while getting nameserver group with ID") t.Error("got a nil group while getting nameserver group with ID")
} }
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") _, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing")
if err == nil { if err == nil {
t.Error("getting not existing nameserver group should return error, got nil") t.Error("getting not existing nameserver group should return error, got nil")
} }
@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
accountID := "testingAcc"
userID := testUserID
domain := "example.com"
peer1 := &nbpeer.Peer{ peer1 := &nbpeer.Peer{
Key: nsGroupPeer1Key, Key: nsGroupPeer1Key,
Name: "test-host1@netbird.io", Name: "test-host1@netbird.io",
@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
} }
existingNSGroup := nbdns.NameServerGroup{ existingNSGroup := nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
AccountID: accountID,
Name: existingNSGroupName, Name: existingNSGroupName,
Description: "", Description: "",
NameServers: []nbdns.NameServer{ NameServers: []nbdns.NameServer{
@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
Enabled: true, Enabled: true,
} }
accountID := "testingAcc" err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
userID := testUserID
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
newGroup1 := &nbgroup.Group{
ID: group1ID,
Name: group1ID,
}
newGroup2 := &nbgroup.Group{
ID: group2ID,
Name: group2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup)
if err != nil {
return "", err
}
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{
{
ID: group1ID,
AccountID: accountID,
Name: group1ID,
},
{
ID: group2ID,
AccountID: accountID,
Name: group2ID,
},
})
if err != nil {
return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
if err != nil { if err != nil {
return nil, err return "", err
} }
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
if err != nil { if err != nil {
return nil, err return "", err
} }
return account, nil return accountID, nil
} }
func TestValidateDomain(t *testing.T) { func TestValidateDomain(t *testing.T) {

View File

@ -467,21 +467,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := "account_creator" adminUser := "account_creator"
someUser := "some_user" someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "") err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
account.Users[someUser] = &User{ require.NoError(t, err, "failed to create account")
Id: someUser,
Role: UserRoleUser,
}
account.Settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccount(context.Background(), account) err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
if err != nil { Id: someUser,
t.Fatal(err) AccountID: accountID,
return Role: UserRoleUser,
} })
require.NoError(t, err, "failed to create user")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
// two peers one added by a regular user and one with a setup key // two peers one added by a regular user and one with a setup key
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
if err != nil { if err != nil {
t.Fatal("error creating setup key") t.Fatal("error creating setup key")
return return
@ -535,7 +539,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer) assert.NotNil(t, peer)
// delete the all-to-all policy so that user's peer1 has no access to peer2 // delete the all-to-all policy so that user's peer1 has no access to peer2
for _, policy := range account.Policies { accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account policies")
for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -654,21 +661,33 @@ func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := "account_creator" adminUser := "account_creator"
someUser := "some_user" someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{ err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
require.NoError(t, err, "failed to create account")
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: someUser, Id: someUser,
AccountID: accountID,
Role: testCase.role, Role: testCase.role,
IsServiceUser: testCase.isServiceUser, IsServiceUser: testCase.isServiceUser,
} })
account.Policies = []*Policy{} require.NoError(t, err, "failed to create user")
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccount(context.Background(), account) accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
if err != nil { require.NoError(t, err, "failed to get account policies")
t.Fatal(err)
return for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
require.NoError(t, err, "failed to delete policy")
} }
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
peerKey1, err := wgtypes.GeneratePrivateKey() peerKey1, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -724,10 +743,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
adminUser := "account_creator" adminUser := "account_creator"
regularUser := "regular_user" regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "") err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
account.Users[regularUser] = &User{ if err != nil {
Id: regularUser, return nil, "", "", err
Role: UserRoleUser, }
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: regularUser,
AccountID: accountID,
Role: UserRoleUser,
})
if err != nil {
return nil, "", "", err
} }
// Create peers // Create peers
@ -741,31 +768,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
UserID: regularUser, UserID: regularUser,
} }
account.Peers[peer.ID] = peer err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, "", "", err
}
} }
// Create groups and policies // Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ { for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i) groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{ group := &nbgroup.Group{
ID: groupID, ID: groupID,
Name: fmt.Sprintf("Group %d", i), AccountID: accountID,
Name: fmt.Sprintf("Group %d", i),
} }
for j := 0; j < peers/groups; j++ { for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
} }
account.Groups[groupID] = group
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
return nil, "", "", err
}
// Create a policy for this group // Create a policy for this group
policy := &Policy{ policy := &Policy{
ID: fmt.Sprintf("policy-%d", i), ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i), AccountID: accountID,
Enabled: true, Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{ Rules: []*PolicyRule{
{ {
ID: fmt.Sprintf("rule-%d", i), ID: fmt.Sprintf("rule-%d", i),
PolicyID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i), Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true, Enabled: true,
Sources: []string{groupID}, Sources: []string{groupID},
@ -776,22 +812,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
}, },
}, },
} }
account.Policies = append(account.Policies, policy)
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, "", "", err
}
} }
account.PostureChecks = []*posture.Checks{ err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{
{ ID: "PostureChecksAll",
ID: "PostureChecksAll", AccountID: accountID,
Name: "All", Name: "All",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1", MinVersion: "0.0.1",
},
}, },
}, },
} })
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, "", "", err return nil, "", "", err
} }

View File

@ -45,7 +45,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
} }
if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil { if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) return status.Errorf(status.InvalidArgument, err.Error()) //nolint
} }
updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate) updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate)

View File

@ -7,6 +7,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/group"
@ -26,22 +27,22 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestPostureChecksAccount(am) accountID, err := initTestPostureChecksAccount(am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
t.Run("Generic posture check flow", func(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks // regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}, false) err := am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}, false)
assert.Error(t, err) assert.Error(t, err)
// regular users cannot list check // regular users cannot list check
_, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) _, err = am.ListPostureChecks(context.Background(), accountID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// should be possible to create posture check with uniq name // should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID, ID: postureCheckID,
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
@ -53,12 +54,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// admin users can list check // admin users can list check
checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, checks, 1) assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name // should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: "new-id", ID: "new-id",
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
@ -74,7 +75,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
// admins can update posture checks // admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID, ID: postureCheckID,
Name: postureCheckName, Name: postureCheckName,
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
@ -86,41 +87,44 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// users should not be able to delete posture checks // users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, regularUserID)
assert.Error(t, err) assert.Error(t, err)
// admin should be able to delete posture checks // admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, checks, 0) assert.Len(t, checks, 0)
}) })
} }
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) {
accountID := "testingAccount" accountID := "testingAccount"
domain := "example.com" domain := "example.com"
admin := &User{ err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
Id: adminUserID,
Role: UserRoleAdmin,
}
user := &User{
Id: regularUserID,
Role: UserRoleUser,
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
return am.Store.GetAccount(context.Background(), account.Id) err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
{
Id: adminUserID,
AccountID: accountID,
Role: UserRoleAdmin,
},
{
Id: regularUserID,
AccountID: accountID,
Role: UserRoleUser,
},
})
if err != nil {
return "", err
}
return accountID, nil
} }
func TestPostureCheckAccountPeersUpdate(t *testing.T) { func TestPostureCheckAccountPeersUpdate(t *testing.T) {
@ -169,7 +173,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@ -192,7 +196,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@ -255,7 +259,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@ -303,7 +307,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
}) })
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
assert.NoError(t, err) assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
@ -337,7 +341,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@ -384,7 +388,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0", MinVersion: "0.29.0",
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@ -429,7 +433,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}, },
}, },
} }
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err) assert.NoError(t, err)
select { select {
@ -441,79 +445,113 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
func TestArePostureCheckChangesAffectingPeers(t *testing.T) { func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
account := &Account{ manager, err := createManager(t)
Policies: []*Policy{ require.NoError(t, err, "failed to create account manager")
{
ID: "policyA", accountID, err := initTestPostureChecksAccount(manager)
Rules: []*PolicyRule{ require.NoError(t, err, "failed to init testing account")
{
Enabled: true, groupA := &group.Group{
Sources: []string{"groupA"}, ID: "groupA",
Destinations: []string{"groupA"}, AccountID: accountID,
}, Peers: []string{"peer1"},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
} }
groupB := &group.Group{
ID: "groupA",
AccountID: accountID,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
policy := &Policy{
ID: "policyA",
AccountID: accountID,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to save policy")
postureCheckA := &posture.Checks{
ID: "checkA",
AccountID: accountID,
}
err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureCheckA)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
ID: "checkB",
AccountID: accountID,
}
err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureCheckB)
require.NoError(t, err, "failed to save postureCheckB")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true) result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkB", true)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check does not exist", func(t *testing.T) { t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false) result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "unknown", false)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"} policy.Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"} policy.Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result) assert.True(t, result)
}) })
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Sources = []string{"nonExistentGroup"}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{} groupA.Peers = []string{}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true) err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
require.NoError(t, err, "failed to save groups")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.False(t, result) assert.False(t, result)
}) })
} }

View File

@ -5,12 +5,15 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@ -427,21 +430,22 @@ func TestCreateRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Errorf("failed to init testing account: %s", err) t.Errorf("failed to init testing account: %s", err)
} }
if testCase.createInitRoute { if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll() groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, errInit) require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit) require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) _, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit) require.NoError(t, errInit)
} }
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err) testCase.errFunc(t, err)
@ -917,14 +921,15 @@ func TestSaveRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
if testCase.createInitRoute { if testCase.createInitRoute {
account.Routes["initRoute"] = &route.Route{ err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, &route.Route{
ID: "initRoute", ID: "initRoute",
AccountID: accountID,
Network: existingNetwork, Network: existingNetwork,
NetID: existingRouteID, NetID: existingRouteID,
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
@ -934,18 +939,14 @@ func TestSaveRoute(t *testing.T) {
Metric: 9999, Metric: 9999,
Enabled: true, Enabled: true,
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
} })
require.NoError(t, err, "failed to save init route")
} }
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute err = am.SaveRoute(context.Background(), accountID, userID, testCase.existingRoute)
require.NoError(t, err, "failed to save existing route")
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("account should be saved")
}
var routeToSave *route.Route var routeToSave *route.Route
if !testCase.skipCopying { if !testCase.skipCopying {
routeToSave = testCase.existingRoute.Copy() routeToSave = testCase.existingRoute.Copy()
if testCase.newPeer != nil { if testCase.newPeer != nil {
@ -977,21 +978,15 @@ func TestSaveRoute(t *testing.T) {
} }
} }
err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) err = am.SaveRoute(context.Background(), accountID, userID, routeToSave)
testCase.errFunc(t, err) testCase.errFunc(t, err)
if !testCase.shouldCreate { if !testCase.shouldCreate {
return return
} }
account, err = am.Store.GetAccount(context.Background(), account.Id) savedRoute, err := am.Store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(testCase.expectedRoute.ID))
if err != nil { require.NoError(t, err, "failed to retrieve saved route")
t.Fatal(err)
}
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved)
if !testCase.expectedRoute.IsEqual(savedRoute) { if !testCase.expectedRoute.IsEqual(savedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
@ -1019,32 +1014,26 @@ func TestDeleteRoute(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
account.Routes[testingRoute.ID] = testingRoute testingRoute.AccountID = accountID
err = am.SaveRoute(context.Background(), accountID, userID, testingRoute)
require.NoError(t, err, "failed to save testing route")
err = am.Store.SaveAccount(context.Background(), account) err = am.DeleteRoute(context.Background(), accountID, testingRoute.ID, userID)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID)
if err != nil { if err != nil {
t.Error("deleting route failed with error: ", err) t.Error("deleting route failed with error: ", err)
} }
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) _, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID)
if err != nil { require.NotNil(t, err)
t.Error("failed to retrieve saved account with error: ", err)
}
_, found := savedAccount.Routes[testingRoute.ID] sErr, ok := status.FromError(err)
if found { require.True(t, ok)
t.Error("route shouldn't be found after delete") require.Equal(t, codes.NotFound, sErr.Type())
}
} }
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
@ -1066,16 +1055,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { require.NoError(t, err, "failed to init testing account")
t.Error("failed to init testing account")
}
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true) require.Equal(t, newRoute.Enabled, true)
@ -1091,7 +1078,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {
@ -1103,21 +1090,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
} }
} }
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID)
require.NoError(t, err) require.NoError(t, err)
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err) require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err) require.NoError(t, err)
peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID)
@ -1128,7 +1115,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@ -1158,16 +1145,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
account, err := initTestRouteAccount(t, am) accountID, err := initTestRouteAccount(t, am)
if err != nil { require.NoError(t, err, "failed to init testing account")
t.Error("failed to init testing account")
}
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err) require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@ -1181,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
expectedRoute := enabledRoute.Copy() expectedRoute := enabledRoute.Copy()
expectedRoute.Peer = peer1Key expectedRoute.Peer = peer1Key
err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute)
require.NoError(t, err) require.NoError(t, err)
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
@ -1193,7 +1178,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group")
err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID)
require.NoError(t, err) require.NoError(t, err)
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
@ -1206,10 +1191,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
Name: "peer1 group", Name: "peer1 group",
Peers: []string{peer1ID}, Peers: []string{peer1ID},
} }
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) err = am.SaveGroup(context.Background(), accountID, userID, newGroup)
require.NoError(t, err) require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") rules, err := am.ListPolicies(context.Background(), accountID, "testingUser")
require.NoError(t, err) require.NoError(t, err)
defaultRule := rules[0] defaultRule := rules[0]
@ -1219,10 +1204,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) err = am.SavePolicy(context.Background(), accountID, userID, newPolicy, false)
require.NoError(t, err) require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@ -1233,7 +1218,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2")
err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@ -1267,179 +1252,103 @@ func createRouterStore(t *testing.T) (Store, error) {
return store, nil return store, nil
} }
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper() t.Helper()
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain) err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, err return "", err
} }
ips := account.getTakenIPs() createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) {
peer1IP, err := AllocatePeerIP(account.Network.Net, ips) ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerIP, err := AllocatePeerIP(network.Net, ips)
if err != nil {
return nil, err
}
peer := &nbpeer.Peer{
IP: peerIP,
ID: peerID,
Key: peerKey,
Name: peerName,
DNSLabel: dnsLabel,
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: peerName,
GoOS: strings.ToLower(os),
Kernel: kernel,
Core: core,
Platform: platform,
OS: os,
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
if err := am.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer); err != nil {
return nil, err
}
return peer, nil
}
// Create peers
peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil { if err != nil {
return nil, err return "", err
} }
peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu")
peer1 := &nbpeer.Peer{
IP: peer1IP,
ID: peer1ID,
Key: peer1Key,
Name: "test-host1@netbird.io",
DNSLabel: "test-host1",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer1.ID] = peer1
ips = account.getTakenIPs()
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin")
peer2 := &nbpeer.Peer{
IP: peer2IP,
ID: peer2ID,
Key: peer2Key,
Name: "test-host2@netbird.io",
DNSLabel: "test-host2",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer2.ID] = peer2
ips = account.getTakenIPs()
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu")
peer3 := &nbpeer.Peer{
IP: peer3IP,
ID: peer3ID,
Key: peer3Key,
Name: "test-host3@netbird.io",
DNSLabel: "test-host3",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
Kernel: "Darwin",
Core: "13.4.1",
Platform: "arm64",
OS: "darwin",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer3.ID] = peer3
ips = account.getTakenIPs()
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu")
peer4 := &nbpeer.Peer{
IP: peer4IP,
ID: peer4ID,
Key: peer4Key,
Name: "test-host4@netbird.io",
DNSLabel: "test-host4",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer4.ID] = peer4
ips = account.getTakenIPs()
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, err return "", err
} }
peer5 := &nbpeer.Peer{ groupAll, err := am.GetGroupByName(context.Background(), "All", accountID)
IP: peer5IP, if err != nil {
ID: peer5ID, return "", err
Key: peer5Key,
Name: "test-host5@netbird.io",
DNSLabel: "test-host5",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host5@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer5.ID] = peer5
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
groupAll, err := account.GetGroupAll()
if err != nil {
return nil, err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID)
if err != nil { if err != nil {
return nil, err return "", err
} }
newGroup := []*nbgroup.Group{ newGroups := []*nbgroup.Group{
{ {
ID: routeGroup1, ID: routeGroup1,
Name: routeGroup1, Name: routeGroup1,
@ -1471,15 +1380,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
Peers: []string{peer1.ID, peer4.ID}, Peers: []string{peer1.ID, peer4.ID},
}, },
} }
err = am.SaveGroups(context.Background(), accountID, userID, newGroups)
for _, group := range newGroup { if err != nil {
err = am.SaveGroup(context.Background(), accountID, userID, group) return "", err
if err != nil {
return nil, err
}
} }
return am.Store.GetAccount(context.Background(), account.Id) return accountID, nil
} }
func TestAccount_getPeersRoutesFirewall(t *testing.T) { func TestAccount_getPeersRoutesFirewall(t *testing.T) {
@ -1783,10 +1689,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t) manager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager") require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager) accountID, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account") require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{ {
ID: "groupA", ID: "groupA",
Name: "GroupA", Name: "GroupA",
@ -1832,7 +1738,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.Groups, []string{}, true, userID, route.KeepRoute,
) )
@ -1868,7 +1774,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.Groups, []string{}, true, userID, route.KeepRoute,
) )
@ -1904,7 +1810,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}() }()
newRoute, err := manager.CreateRoute( newRoute, err := manager.CreateRoute(
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
) )
@ -1928,7 +1834,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute)
require.NoError(t, err) require.NoError(t, err)
select { select {
@ -1946,7 +1852,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID)
require.NoError(t, err) require.NoError(t, err)
select { select {
@ -1970,7 +1876,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
} }
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
) )
@ -1982,7 +1888,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupB", ID: "groupB",
Name: "GroupB", Name: "GroupB",
Peers: []string{peer1ID}, Peers: []string{peer1ID},
@ -2010,7 +1916,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{"groupC"}, Groups: []string{"groupC"},
} }
_, err := manager.CreateRoute( _, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
) )
@ -2022,7 +1928,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done) close(done)
}() }()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupC", ID: "groupC",
Name: "GroupC", Name: "GroupC",
Peers: []string{peer1ID}, Peers: []string{peer1ID},

View File

@ -25,12 +25,10 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { require.NoError(t, err, "failed to get or create account ID")
t.Fatal(err)
}
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{ {
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
@ -49,7 +47,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
expiresIn := time.Hour expiresIn := time.Hour
keyName := "my-test-key" keyName := "my-test-key"
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{},
SetupKeyUnlimitedUsage, userID, false) SetupKeyUnlimitedUsage, userID, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -58,7 +56,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups := []string{"group_1", "group_2"} autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key" newKeyName := "my-new-test-key"
revoked := true revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName, Name: newKeyName,
Revoked: revoked, Revoked: revoked,
@ -72,22 +70,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
key.Id, time.Now().UTC(), autoGroups, true) key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked)
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, newKeyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID) assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll() groupAll, err := manager.GetGroupByName(context.Background(), accountID, "All")
assert.NoError(t, err) require.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error // saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID) autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ _, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id, Id: key.Id,
Name: newKeyName, Name: newKeyName,
Revoked: revoked, Revoked: revoked,
@ -103,12 +101,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { require.NoError(t, err, "failed to get or create account ID")
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@ -117,7 +113,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},
@ -126,8 +122,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
groupAll, err := account.GetGroupAll() groupAll, err := manager.GetGroupByName(context.Background(), accountID, "All")
assert.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
name string name string
@ -170,7 +166,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
for _, tCase := range []testCase{testCase1, testCase2, testCase3} { for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
if tCase.expectedFailure { if tCase.expectedFailure {
@ -189,10 +185,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
tCase.expectedUpdatedAt, tCase.expectedGroups, false) tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) ev := getEvent(t, accountID, manager, activity.SetupKeyCreated)
assert.NotNil(t, ev) assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID) assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"]) assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"])
assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"])) assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"]) assert.NotEmpty(t, ev.Meta["key"])
@ -208,12 +204,10 @@ func TestGetSetupKeys(t *testing.T) {
} }
userID := "testingUser" userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil { require.NoError(t, err, "failed to get or create account ID")
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@ -222,7 +216,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},

View File

@ -68,17 +68,23 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
func runLargeTest(t *testing.T, store Store) { func runLargeTest(t *testing.T, store Store) {
t.Helper() t.Helper()
account := newAccountWithId(context.Background(), "account_id", "testuser", "") accountID := "account_id"
groupALL, err := account.GetGroupAll()
if err != nil { err := newAccountWithId(context.Background(), store, accountID, "testuser", "")
t.Fatal(err) assert.NoError(t, err, "failed to create account")
}
groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
assert.NoError(t, err, "failed to get group All")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
assert.NoError(t, err, "failed to save setup key")
const numPerAccount = 6000 const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ { for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4() netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) peerID := fmt.Sprintf("%s-peer-%d", accountID, n)
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
ID: peerID, ID: peerID,
@ -90,16 +96,21 @@ func runLargeTest(t *testing.T, store Store) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false, SSHEnabled: false,
} }
account.Peers[peerID] = peer err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
group, _ := account.GetGroupAll() assert.NoError(t, err, "failed to save peer")
group.Peers = append(group.Peers, peerID)
user := &User{ err = store.AddPeerToAllGroup(context.Background(), accountID, peerID)
Id: fmt.Sprintf("%s-user-%d", account.Id, n), assert.NoError(t, err, "failed to add peer to all group")
AccountID: account.Id,
} err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
account.Users[user.Id] = user Id: fmt.Sprintf("%s-user-%d", accountID, n),
AccountID: accountID,
})
assert.NoError(t, err, "failed to save user")
route := &route2.Route{ route := &route2.Route{
ID: route2.ID(fmt.Sprintf("network-id-%d", n)), ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
AccountID: accountID,
Description: "base route", Description: "base route",
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"), Network: netip.MustParsePrefix(netIP.String() + "/24"),
@ -107,22 +118,24 @@ func runLargeTest(t *testing.T, store Store) {
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
Enabled: true, Enabled: true,
Groups: []string{groupALL.ID}, Groups: []string{groupAll.ID},
} }
account.Routes[route.ID] = route err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
assert.NoError(t, err, "failed to save route")
group = &nbgroup.Group{ group := &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n), ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id, AccountID: accountID,
Name: fmt.Sprintf("group-id-%d", n), Name: fmt.Sprintf("group-id-%d", n),
Issued: "api", Issued: "api",
Peers: nil, Peers: nil,
} }
account.Groups[group.ID] = group err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
assert.NoError(t, err, "failed to save group")
nameserver := &nbdns.NameServerGroup{ nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n), ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id, AccountID: accountID,
Name: fmt.Sprintf("nameserver-id-%d", n), Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "", Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
@ -132,20 +145,20 @@ func runLargeTest(t *testing.T, store Store) {
Enabled: false, Enabled: false,
SearchDomainsEnabled: false, SearchDomainsEnabled: false,
} }
account.NameServerGroups[nameserver.ID] = nameserver err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver)
assert.NoError(t, err, "failed to save nameserver group")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
assert.NoError(t, err, "failed to save setup key")
} }
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 { if len(store.GetAllAccounts(context.Background())) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
a, err := store.GetAccount(context.Background(), account.Id) a, err := store.GetAccount(context.Background(), accountID)
if a == nil { if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
} }
@ -213,41 +226,49 @@ func TestSqlite_SaveAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) require.NoError(t, err)
accountID := "account_id"
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
require.NoError(t, err, "failed to create account")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
require.NoError(t, err, "failed to save peer")
err = store.SaveAccount(context.Background(), account) accountID2 := "account_id2"
require.NoError(t, err) err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
require.NoError(t, err, "failed to create account")
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID2
account2.Peers["testpeer2"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID2, &nbpeer.Peer{
Key: "peerkey2", Key: "peerkey2",
IP: net.IP{127, 0, 0, 2}, IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2", Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
require.NoError(t, err, "failed to save peer")
err = store.SaveAccount(context.Background(), account2)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 2 { if len(store.GetAllAccounts(context.Background())) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
} }
a, err := store.GetAccount(context.Background(), account.Id) a, err := store.GetAccount(context.Background(), accountID)
if a == nil { if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
} }
@ -288,36 +309,56 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
accountID := "account_id"
testUserID := "testuser" testUserID := "testuser"
user := NewAdminUser(testUserID) user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": { user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken", ID: "testtoken",
Name: "test token", Name: "test token",
}} }}
account := newAccountWithId(context.Background(), "account_id", testUserID, "") err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
require.NoError(t, err)
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
account.Users[testUserID] = user require.NoError(t, err, "failed to save peer")
err = store.SaveAccount(context.Background(), account) err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
require.NoError(t, err) ID: "testtoken",
UserID: testUserID,
Name: "test token",
})
require.NoError(t, err, "failed to save personal access token")
if len(store.GetAllAccounts(context.Background())) != 1 { accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
err = store.DeleteAccount(context.Background(), account) err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 { accountIDs, err = store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get all account ids after DeleteAccount()")
if len(accountIDs) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
} }
@ -714,19 +755,28 @@ func newSqliteStore(t *testing.T) *SqlStore {
} }
func newAccount(store Store, id int) error { func newAccount(store Store, id int) error {
str := fmt.Sprintf("%s-%d", uuid.New().String(), id) accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id)
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") userID := accountID + "-testuser"
err := newAccountWithId(context.Background(), store, accountID, userID, "example.com")
if err != nil {
return err
}
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["p"+str] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
Key: "peerkey" + str, if err != nil {
return err
}
return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: accountID + "-peerkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
return store.SaveAccount(context.Background(), account)
} }
func TestPostgresql_NewStore(t *testing.T) { func TestPostgresql_NewStore(t *testing.T) {
@ -754,39 +804,52 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "") accountID := "account_id"
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
require.NoError(t, err, "failed to create account")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
require.NoError(t, err, "failed to save peer")
err = store.SaveAccount(context.Background(), account) accountID2 := "account_id2"
require.NoError(t, err)
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
require.NoError(t, err, "failed to create account")
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey() setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID2
account2.Peers["testpeer2"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID2, &nbpeer.Peer{
Key: "peerkey2", Key: "peerkey2",
IP: net.IP{127, 0, 0, 2}, IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2", Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
require.NoError(t, err, "failed to save peer")
err = store.SaveAccount(context.Background(), account2) accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
require.NoError(t, err) require.NoError(t, err, "failed to get all account ids")
if len(store.GetAllAccounts(context.Background())) != 2 { if len(accountIDs) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
} }
a, err := store.GetAccount(context.Background(), account.Id) a, err := store.GetAccount(context.Background(), accountID)
if a == nil { if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
} }
@ -827,32 +890,49 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
accountID := "account_id"
testUserID := "testuser" testUserID := "testuser"
user := NewAdminUser(testUserID) user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": { user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken", ID: "testtoken",
Name: "test token", Name: "test token",
}} }}
account := newAccountWithId(context.Background(), "account_id", testUserID, "") err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
require.NoError(t, err, "failed to create account")
setupKey, _ := GenerateDefaultSetupKey() setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey setupKey.AccountID = accountID
account.Peers["testpeer"] = &nbpeer.Peer{ err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} })
account.Users[testUserID] = user require.NoError(t, err, "failed to save peer")
err = store.SaveAccount(context.Background(), account) err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
require.NoError(t, err) ID: "testtoken",
UserID: testUserID,
Name: "test token",
})
require.NoError(t, err, "failed to save personal access token")
if len(store.GetAllAccounts(context.Background())) != 1 { accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
} }
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
err = store.DeleteAccount(context.Background(), account) err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err) require.NoError(t, err)
@ -1218,7 +1298,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "example.com" domain := "example.com"
category := "public" category := "public"
IsDomainPrimaryAccount := false IsDomainPrimaryAccount := false
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
require.NoError(t, err) require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
@ -1232,7 +1312,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "test.com" domain := "test.com"
category := "private" category := "private"
IsDomainPrimaryAccount := true IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
require.NoError(t, err) require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID) account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err) require.NoError(t, err)
@ -1246,7 +1326,9 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "test.com" domain := "test.com"
category := "private" category := "private"
IsDomainPrimaryAccount := true IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id",
domain, category, &IsDomainPrimaryAccount,
)
require.Error(t, err) require.Error(t, err)
}) })
@ -1274,7 +1356,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
require.NoError(t, err) require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
@ -1290,6 +1372,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id" nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err) require.Error(t, err)
} }

View File

@ -962,7 +962,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountIDByUser(ctx context.Context,
} }
} }
return "", nil return accountID, nil
} }
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return

View File

@ -43,12 +43,9 @@ const (
func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId, Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: false, IsServiceUser: false,
} })
err := store.SaveAccount(context.Background(), account) assert.NoError(t, err, "failed to create user")
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId, Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: true, IsServiceUser: true,
} })
err := store.SaveAccount(context.Background(), account) assert.NoError(t, err, "failed to create user")
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
func TestUser_CreatePAT_WithEmptyName(t *testing.T) { func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
} }
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
assert.Errorf(t, err, "Wrong expiration should thorw error") assert.Errorf(t, err, "Wrong expiration should throw error")
} }
func TestUser_DeletePAT(t *testing.T) { func TestUser_DeletePAT(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
PATs: map[string]*PersonalAccessToken{
mockTokenID1: { err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID1, ID: mockTokenID1,
HashedToken: mockToken1, UserID: mockUserID,
}, HashedToken: mockToken1,
}, })
} assert.NoError(t, err, "failed to create PAT")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
if err != nil { if err != nil {
t.Fatalf("Error when getting account: %s", err) t.Fatalf("Error when getting account: %s", err)
} }
@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) {
func TestUser_GetPAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
mockTokenID1: { ID: mockTokenID1,
ID: mockTokenID1, UserID: mockUserID,
HashedToken: mockToken1, HashedToken: mockToken1,
}, })
}, assert.NoError(t, err, "failed to create PAT")
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) {
func TestUser_GetAllPATs(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: mockUserID, assert.NoError(t, err, "failed to create account")
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{ err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
mockTokenID1: { ID: mockTokenID1,
ID: mockTokenID1, UserID: mockUserID,
HashedToken: mockToken1, HashedToken: mockToken1,
}, })
mockTokenID2: { assert.NoError(t, err, "failed to create PAT")
ID: mockTokenID2,
HashedToken: mockToken2, err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
}, ID: mockTokenID2,
}, UserID: mockUserID,
} HashedToken: mockToken2,
err := store.SaveAccount(context.Background(), account) })
if err != nil { assert.NoError(t, err, "failed to create PAT")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) {
func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when creating service user: %s", err) t.Fatalf("Error when creating service user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 2, len(account.Users)) assert.Equal(t, 2, len(account.Users))
@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) {
func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when creating user: %s", err) t.Fatalf("Error when creating user: %s", err)
} }
account, err = store.GetAccount(context.Background(), mockAccountID) account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, user.IsServiceUser) assert.True(t, user.IsServiceUser)
@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
func TestUser_InviteNewUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -549,13 +519,12 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = tt.serviceUser
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser)
assert.NoError(t, err, "failed to create service user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -582,12 +551,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -603,39 +569,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2" err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[targetId] = &User{ assert.NoError(t, err, "failed to create account")
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5" err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
account.Users[targetId] = &User{ {
Id: targetId, Id: "user2",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: true,
Role: UserRoleOwner, ServiceUserName: "user2username",
} },
{
err := store.SaveAccount(context.Background(), account) Id: "user3",
if err != nil { AccountID: mockAccountID,
t.Fatalf("Error when saving account: %s", err) IsServiceUser: false,
} Issued: UserIssuedAPI,
},
{
Id: "user4",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedIntegration,
},
{
Id: "user5",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
},
})
assert.NoError(t, err, "failed to save users")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -685,61 +650,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
func TestUser_DeleteUser_RegularUsers(t *testing.T) { func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2" err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[targetId] = &User{ assert.NoError(t, err, "failed to create account")
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
targetId = "user5" err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
account.Users[targetId] = &User{ {
Id: targetId, Id: "user2",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: true,
Role: UserRoleOwner, ServiceUserName: "user2username",
} },
account.Users["user6"] = &User{ {
Id: "user6", Id: "user3",
IsServiceUser: false, AccountID: mockAccountID,
Issued: UserIssuedAPI, IsServiceUser: false,
} Issued: UserIssuedAPI,
account.Users["user7"] = &User{ },
Id: "user7", {
IsServiceUser: false, Id: "user4",
Issued: UserIssuedAPI, AccountID: mockAccountID,
} IsServiceUser: false,
account.Users["user8"] = &User{ Issued: UserIssuedIntegration,
Id: "user8", },
IsServiceUser: false, {
Issued: UserIssuedAPI, Id: "user5",
Role: UserRoleAdmin, AccountID: mockAccountID,
} IsServiceUser: false,
account.Users["user9"] = &User{ Issued: UserIssuedAPI,
Id: "user9", Role: UserRoleOwner,
IsServiceUser: false, },
Issued: UserIssuedAPI, {
Role: UserRoleAdmin, Id: "user6",
} AccountID: mockAccountID,
IsServiceUser: false,
err := store.SaveAccount(context.Background(), account) Issued: UserIssuedAPI,
if err != nil { },
t.Fatalf("Error when saving account: %s", err) {
} Id: "user7",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user8",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
{
Id: "user9",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
})
assert.NoError(t, err)
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -816,7 +784,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
acc, err := am.Store.GetAccount(context.Background(), account.Id) acc, err := am.Store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {
@ -836,12 +804,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -865,14 +830,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} newUser := NewRegularUser("normal_user1")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
newUser = NewRegularUser("normal_user2")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -946,15 +916,24 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
err := store.SaveAccount(context.Background(), account) err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
if err != nil { assert.NoError(t, err, "failed to create account")
t.Fatalf("Error when saving account: %s", err)
} newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID)
assert.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings)
assert.NoError(t, err, "failed to save account settings")
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID)
assert.NoError(t, err, "failed to delete user")
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -968,7 +947,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
assert.Equal(t, 1, len(users)) assert.Equal(t, 1, len(users))
userInfo, _ := users[0].ToUserInfo(nil, account.Settings) userInfo, _ := users[0].ToUserInfo(nil, settings)
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
}) })
} }
@ -978,22 +957,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
externalUser := &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
Id: "externalUser", assert.NoError(t, err, "failed to create account")
Role: UserRoleUser,
Issued: UserIssuedIntegration, err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: "externalUser",
AccountID: mockAccountID,
Role: UserRoleUser,
Issued: UserIssuedIntegration,
IntegrationReference: integration_reference.IntegrationReference{ IntegrationReference: integration_reference.IntegrationReference{
ID: 1, ID: 1,
IntegrationType: "external", IntegrationType: "external",
}, },
} })
account.Users[externalUser.Id] = externalUser assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -1013,6 +991,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager() cacheManager := am.GetExternalCacheManager()
externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser")
assert.NoError(t, err, "failed to get user")
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"})
assert.NoError(t, err) assert.NoError(t, err)
@ -1042,17 +1024,17 @@ func TestUser_IsAdmin(t *testing.T) {
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{ err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID, Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
} })
assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -1071,17 +1053,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t) store := newStore(t)
defer store.Close(context.Background()) defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{ assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID, Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user", Role: "user",
IsServiceUser: true, IsServiceUser: true,
} })
assert.NoError(t, err, "failed to create user")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
@ -1240,21 +1221,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// create an account and an admin user // create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// create other users // create other users
account.Users[regularUserID] = NewRegularUser(regularUserID) regularUser := NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID) regularUser.AccountID = accountID
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(context.Background(), account) adminUser := NewAdminUser(adminUserID)
if err != nil { adminUser.AccountID = accountID
t.Fatal(err)
serviceUser := &User{
Id: serviceUserID,
AccountID: accountID,
IsServiceUser: true,
Role: UserRoleAdmin,
ServiceUserName: "service",
} }
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser})
assert.NoError(t, err, "failed to save users")
updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update)
if tc.expectedErr { if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error") require.Errorf(t, err, "expecting SaveUser to throw an error")
} else { } else {