fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-22 23:44:10 +03:00
parent 8f98adddf6
commit 7601a17150
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
20 changed files with 199 additions and 160 deletions

View File

@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID)
}
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount = acc
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{}
@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
if account == nil {
if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId)
return
}
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings)
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings)
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})

View File

@ -962,6 +962,14 @@ func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (st
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ string) (string, string, error) {
return "", "", status.Errorf(status.Internal, "GetAccountDomainAndCategory is not implemented")
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return "", "", err
}
return account.Domain, account.DomainCategory, nil
}

View File

@ -23,8 +23,11 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return account, admin, nil
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account.Id, admin.Id, nil
},
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour

View File

@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}
return []*activity.Event{}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)
handler := initEventsTestData(accountID, events...)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@ -11,9 +11,9 @@ import (
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
}, user, nil
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return server.NewAdminUser(id), nil
},
},
geolocationManager: geo,

View File

@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return nil
},
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found")
}
if groupID == "id-jwt-group" {
return &nbgroup.Group{
ID: "id-jwt-group",
Name: "Default Group",
Issued: nbgroup.GroupIssuedJWT,
}, nil
}
return &nbgroup.Group{
ID: "idofthegroup",
Name: "Group",
Issued: nbgroup.GroupIssuedAPI,
}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*nbgroup.Group{
groups := map[string]*nbgroup.Group{
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
}
for _, group := range initGroups {
groups[group.ID] = group
}
group, ok := groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "not found")
}
return group, nil
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
if groupName == "All" {
return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
}
return nil, fmt.Errorf("unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@ -29,14 +28,6 @@ const (
testNSGroupAccountID = "test_id"
)
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
}, nil
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
AccountId: testNSGroupAccountID,
AccountId: existingAccountID,
}
}),
),

View File

@ -13,16 +13,15 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
policy := &server.Policy{
ID: "policy",
AccountID: claims.AccountId,
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
srvUser.IsServiceUser = true
account := &server.Account{
Id: claims.AccountId,
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: claims.AccountId,
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
}
return account, account.Users[claims.UserId], nil
return account, nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
assert.Equal(t, respBody[1].Connected, false)
got = respBody[0]
for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
}
}
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)

View File

@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
user := server.NewAdminUser(userID)
return &server.Account{
Id: claims.AccountId,
Id: accountID,
Domain: "hotmail.com",
Policies: []*server.Policy{
{ID: "id-existed"},
@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -73,9 +73,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return
}
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID)
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
util.WriteError(r.Context(), err, w)
return
}

View File

@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
PostureChecks: postureChecks,
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Geolocation{},

View File

@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
}
if peerID != "" {
if peerID == nonLinuxExistingPeerID {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
return &route.Route{
ID: existingRouteID,
NetID: netID,
@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
}
if r.Peer == nonLinuxExistingPeerID {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
return nil
},
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
}
return nil
},
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
//return testingAccount, testingAccount.Users["test_user"], nil
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler {
return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*nbgroup.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"},
},
}, user, nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,

View File

@ -49,7 +49,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
@ -79,7 +79,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
Id: userID,
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,

View File

@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler {
return &UsersHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)

View File

@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return

View File

@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)

View File

@ -796,7 +796,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err)
}
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
assert.NoError(t, err)
acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {