mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 19:00:50 +01:00
fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
8f98adddf6
commit
7601a17150
@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
assert.Equal(t, account.Id, ev.TargetID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
|
||||
@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
||||
// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
initAccount = acc
|
||||
initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||
@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*group.Group{}
|
||||
@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected an error when user and account IDs are empty")
|
||||
}
|
||||
@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
assert.NotNil(t, account.Settings)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.NotNil(t, settings)
|
||||
assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "unable to get account by ID")
|
||||
|
||||
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
|
@ -962,6 +962,14 @@ func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (st
|
||||
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ string) (string, string, error) {
|
||||
return "", "", status.Errorf(status.Internal, "GetAccountDomainAndCategory is not implemented")
|
||||
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return account.Domain, account.DomainCategory, nil
|
||||
}
|
||||
|
@ -23,8 +23,11 @@ import (
|
||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
||||
return &AccountsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return account, admin, nil
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return account.Id, admin.Id, nil
|
||||
},
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
|
@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
|
||||
}
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
},
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
|
||||
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
|
||||
return &EventsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
||||
}
|
||||
return []*activity.Event{}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
user.Id: user,
|
||||
},
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
return make([]*server.UserInfo, 0), nil
|
||||
@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
events := generateEvents(accountID, adminUser.Id)
|
||||
handler := initEventsTestData(accountID, adminUser, events...)
|
||||
handler := initEventsTestData(accountID, events...)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -11,9 +11,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
||||
|
||||
return &GeolocationsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
}, user, nil
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
return server.NewAdminUser(id), nil
|
||||
},
|
||||
},
|
||||
geolocationManager: geo,
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/magiconair/properties/assert"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||
}
|
||||
|
||||
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
|
||||
return &GroupsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||
@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
return nil
|
||||
},
|
||||
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
||||
if groupID != "idofthegroup" {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
if groupID == "id-jwt-group" {
|
||||
return &nbgroup.Group{
|
||||
ID: "id-jwt-group",
|
||||
Name: "Default Group",
|
||||
Issued: nbgroup.GroupIssuedJWT,
|
||||
}, nil
|
||||
}
|
||||
return &nbgroup.Group{
|
||||
ID: "idofthegroup",
|
||||
Name: "Group",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: TestPeers,
|
||||
Users: map[string]*server.User{
|
||||
user.Id: user,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
groups := map[string]*nbgroup.Group{
|
||||
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
|
||||
"id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
|
||||
}
|
||||
|
||||
for _, group := range initGroups {
|
||||
groups[group.ID] = group
|
||||
}
|
||||
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
|
||||
return group, nil
|
||||
},
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown group name")
|
||||
},
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return maps.Values(TestPeers), nil
|
||||
},
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
if groupID == "linked-grp" {
|
||||
@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
|
||||
Name: "Group",
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser, group)
|
||||
p := initGroupTestData(group)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser)
|
||||
p := initGroupTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
@ -29,14 +28,6 @@ const (
|
||||
testNSGroupAccountID = "test_id"
|
||||
)
|
||||
|
||||
var testingNSAccount = &server.Account{
|
||||
Id: testNSGroupAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
},
|
||||
}
|
||||
|
||||
var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
Name: "super",
|
||||
@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
|
||||
}
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: testNSGroupAccountID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
|
@ -13,16 +13,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: claims.AccountId,
|
||||
AccountID: accountID,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: claims.AccountId,
|
||||
AccountID: accountID,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
},
|
||||
}
|
||||
|
||||
return account, account.Users[claims.UserId], nil
|
||||
return account, nil
|
||||
},
|
||||
HasConnectedChannelFunc: func(peerID string) bool {
|
||||
statuses := make(map[string]struct{})
|
||||
@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
// hardcode this check for now as we only have two peers in this suite
|
||||
assert.Equal(t, len(respBody), 2)
|
||||
assert.Equal(t, respBody[1].Connected, false)
|
||||
|
||||
got = respBody[0]
|
||||
for _, peer := range respBody {
|
||||
if peer.Id == testPeerID {
|
||||
got = peer
|
||||
} else {
|
||||
assert.Equal(t, peer.Connected, false)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
got = &api.Peer{}
|
||||
err = json.Unmarshal(content, got)
|
||||
|
@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
user := server.NewAdminUser(userID)
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Policies: []*server.Policy{
|
||||
{ID: "id-existed"},
|
||||
@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
},
|
||||
}, user, nil
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -73,9 +73,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID)
|
||||
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
}
|
||||
return accountPostureChecks, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
},
|
||||
PostureChecks: postureChecks,
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
geolocationManager: &geolocation.Geolocation{},
|
||||
|
@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler {
|
||||
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
|
||||
}
|
||||
if peerID != "" {
|
||||
if peerID == nonLinuxExistingPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
return &route.Route{
|
||||
ID: existingRouteID,
|
||||
NetID: netID,
|
||||
@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler {
|
||||
if r.Peer == notFoundPeerID {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
||||
}
|
||||
|
||||
if r.Peer == nonLinuxExistingPeerID {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
||||
@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
//return testingAccount, testingAccount.Users["test_user"], nil
|
||||
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
) *SetupKeysHandler {
|
||||
return &SetupKeysHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
user.Id: user,
|
||||
},
|
||||
SetupKeys: map[string]*server.SetupKey{
|
||||
defaultKey.Key: defaultKey,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"},
|
||||
},
|
||||
}, user, nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string, ephemeral bool,
|
||||
|
@ -49,7 +49,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
@ -79,7 +79,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||
Id: userID,
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Blocked: req.IsBlocked,
|
||||
|
@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
|
||||
func initUsersTestData() *UsersHandler {
|
||||
return &UsersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return usersTestAccount.Id, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
|
@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
policy.Enabled = false
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
|
@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||
|
@ -796,7 +796,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), accID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, id := range tc.expectedDeleted {
|
||||
|
Loading…
Reference in New Issue
Block a user