fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-22 23:44:10 +03:00
parent 8f98adddf6
commit 7601a17150
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
20 changed files with 199 additions and 160 deletions

View File

@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID) assert.Equal(t, account.Id, ev.TargetID)
} }
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims type initUserParams jwtclaims.AuthorizationClaims
type test struct { type test struct {
@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs { if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed") require.NoError(t, err, "update init user failed")
@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization // as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount = acc initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
} }
t.Run("JWT groups disabled", func(t *testing.T) { t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, account.Groups, 1, "only ALL group should exists")
}) })
@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
}) })
@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 3, "groups should be added to the account") require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{} groupsByNames := map[string]*group.Group{}
@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user" userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if account == nil { if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
return return
} }
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
} }
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil { if err == nil {
t.Errorf("expected an error when user and account IDs are empty") t.Errorf("expected an error when user and account IDs are empty")
} }
@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err) t.Errorf("delete default rule: %v", err)
return return
} }
@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) require.NoError(t, err, "unable to get account settings")
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
assert.NotNil(t, settings)
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
} }
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
}) })
@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}, },
} }
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger // when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account") require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID") require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled) settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) require.NoError(t, err, "unable to get account settings")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second, PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
}) })

View File

@ -962,6 +962,14 @@ func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (st
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
} }
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ string) (string, string, error) { func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) {
return "", "", status.Errorf(status.Internal, "GetAccountDomainAndCategory is not implemented") s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return "", "", err
}
return account.Domain, account.DomainCategory, nil
} }

View File

@ -23,8 +23,11 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{ return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account, admin, nil return account.Id, admin.Id, nil
},
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil
}, },
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour

View File

@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
} }
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}, },
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
return &EventsHandler{ return &EventsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
} }
return []*activity.Event{}, nil return []*activity.Event{}, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return &server.Account{ return claims.AccountId, claims.UserId, nil
Id: claims.AccountId,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
}, user, nil
}, },
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil return make([]*server.UserInfo, 0), nil
@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
accountID := "test_account" accountID := "test_account"
adminUser := server.NewAdminUser("test_user") adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id) events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...) handler := initEventsTestData(accountID, events...)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -11,9 +11,9 @@ import (
"testing" "testing"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{ return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
user := server.NewAdminUser("test_user") return claims.AccountId, claims.UserId, nil
return &server.Account{ },
Id: claims.AccountId, GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
Users: map[string]*server.User{ return server.NewAdminUser(id), nil
"test_user": user,
},
}, user, nil
}, },
}, },
geolocationManager: geo, geolocationManager: geo,

View File

@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{ return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return nil return nil
}, },
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" { groups := map[string]*nbgroup.Group{
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
}
for _, group := range initGroups {
groups[group.ID] = group
}
group, ok := groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "not found") return nil, status.Errorf(status.NotFound, "not found")
} }
if groupID == "id-jwt-group" {
return &nbgroup.Group{ return group, nil
ID: "id-jwt-group",
Name: "Default Group",
Issued: nbgroup.GroupIssuedJWT,
}, nil
}
return &nbgroup.Group{
ID: "idofthegroup",
Name: "Group",
Issued: nbgroup.GroupIssuedAPI,
}, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return &server.Account{ return claims.AccountId, claims.UserId, nil
Id: claims.AccountId, },
Domain: "hotmail.com", GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
Peers: TestPeers, if groupName == "All" {
Users: map[string]*server.User{ return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
user.Id: user, }
},
Groups: map[string]*nbgroup.Group{ return nil, fmt.Errorf("unknown group name")
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, },
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, return maps.Values(TestPeers), nil
},
}, user, nil
}, },
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" { if groupID == "linked-grp" {
@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group", Name: "Group",
} }
adminUser := server.NewAdminUser("test_user") p := initGroupTestData(group)
p := initGroupTestData(adminUser, group)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
}, },
} }
adminUser := server.NewAdminUser("test_user") p := initGroupTestData()
p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
}, },
} }
adminUser := server.NewAdminUser("test_user") p := initGroupTestData()
p := initGroupTestData(adminUser)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
@ -29,14 +28,6 @@ const (
testNSGroupAccountID = "test_id" testNSGroupAccountID = "test_id"
) )
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{ var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID, ID: existingNSGroupID,
Name: "super", Name: "super",
@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
} }
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
}, },
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil return claims.AccountId, claims.UserId, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
}, nil }, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return testAccount, testAccount.Users[existingUserID], nil return claims.AccountId, claims.UserId, nil
}, },
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID { if accountID != existingAccountID {
@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: existingUserID, UserId: existingUserID,
Domain: testDomain, Domain: testDomain,
AccountId: testNSGroupAccountID, AccountId: existingAccountID,
} }
}), }),
), ),

View File

@ -13,16 +13,15 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func() string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers { for _, peer := range peers {
peersMap[peer.ID] = peer.Copy() peersMap[peer.ID] = peer.Copy()
@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
policy := &server.Policy{ policy := &server.Policy{
ID: "policy", ID: "policy",
AccountID: claims.AccountId, AccountID: accountID,
Name: "policy", Name: "policy",
Enabled: true, Enabled: true,
Rules: []*server.PolicyRule{ Rules: []*server.PolicyRule{
@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
srvUser.IsServiceUser = true srvUser.IsServiceUser = true
account := &server.Account{ account := &server.Account{
Id: claims.AccountId, Id: accountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: peersMap, Peers: peersMap,
Users: map[string]*server.User{ Users: map[string]*server.User{
@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
Groups: map[string]*nbgroup.Group{ Groups: map[string]*nbgroup.Group{
"group1": { "group1": {
ID: "group1", ID: "group1",
AccountID: claims.AccountId, AccountID: accountID,
Name: "group1", Name: "group1",
Issued: "api", Issued: "api",
Peers: maps.Keys(peersMap), Peers: maps.Keys(peersMap),
@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}, },
} }
return account, account.Users[claims.UserId], nil return account, nil
}, },
HasConnectedChannelFunc: func(peerID string) bool { HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{}) statuses := make(map[string]struct{})
@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
// hardcode this check for now as we only have two peers in this suite // hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2) assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
got = respBody[0] for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
}
}
} else { } else {
got = &api.Peer{} got = &api.Peer{}
err = json.Unmarshal(content, got) err = json.Unmarshal(content, got)

View File

@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
} }
return policy, nil return policy, nil
}, },
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") { if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set" policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set" policy.Rules[0].ID = "id-was-set"
} }
return nil return nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
user := server.NewAdminUser("test_user") return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
user := server.NewAdminUser(userID)
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: accountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Policies: []*server.Policy{ Policies: []*server.Policy{
{ID: "id-existed"}, {ID: "id-existed"},
@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": user, "test_user": user,
}, },
}, user, nil }, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -73,9 +73,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return return
} }
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID) _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
} }
return accountPostureChecks, nil return accountPostureChecks, nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
user := server.NewAdminUser("test_user") return claims.AccountId, claims.UserId, nil
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
PostureChecks: postureChecks,
}, user, nil
}, },
}, },
geolocationManager: &geolocation.Geolocation{}, geolocationManager: &geolocation.Geolocation{},

View File

@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
} }
if peerID != "" {
if peerID == nonLinuxExistingPeerID {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
return &route.Route{ return &route.Route{
ID: existingRouteID, ID: existingRouteID,
NetID: netID, NetID: netID,
@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
if r.Peer == notFoundPeerID { if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
} }
if r.Peer == nonLinuxExistingPeerID {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
return nil return nil
}, },
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil return nil
}, },
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingAccount, testingAccount.Users["test_user"], nil //return testingAccount, testingAccount.Users["test_user"], nil
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler { ) *SetupKeysHandler {
return &SetupKeysHandler{ return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return &server.Account{ return claims.AccountId, claims.UserId, nil
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*nbgroup.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"},
},
}, user, nil
}, },
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool, _ int, _ string, ephemeral bool,

View File

@ -49,7 +49,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
targetUserID := vars["userId"] targetUserID := vars["userId"]
if len(userID) == 0 { if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
@ -79,7 +79,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
} }
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
Id: userID, Id: targetUserID,
Role: userRole, Role: userRole,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked, Blocked: req.IsBlocked,

View File

@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler { func initUsersTestData() *UsersHandler {
return &UsersHandler{ return &UsersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return usersTestAccount.Users[id], nil
}, },
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)

View File

@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept, Action: PolicyTrafficActionAccept,
}, },
} }
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return
@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
} }
policy.Enabled = false policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
if err != nil { if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err) t.Errorf("expecting rule to be added, got failure %v", err)
return return

View File

@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
require.NoError(t, err) require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)

View File

@ -796,7 +796,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err) assert.NoError(t, err)
for _, id := range tc.expectedDeleted { for _, id := range tc.expectedDeleted {