From 7601a171503040024b53623abd627f93f83be16b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sun, 22 Sep 2024 23:44:10 +0300 Subject: [PATCH] fix tests Signed-off-by: bcmmbaga --- management/server/account_test.go | 107 ++++++++++++------ management/server/file_store.go | 12 +- .../server/http/accounts_handler_test.go | 7 +- .../server/http/dns_settings_handler_test.go | 4 +- management/server/http/events_handler_test.go | 14 +-- .../server/http/geolocation_handler_test.go | 15 +-- management/server/http/groups_handler_test.go | 65 +++++------ .../server/http/nameservers_handler_test.go | 13 +-- management/server/http/pat_handler_test.go | 6 +- management/server/http/peers_handler_test.go | 28 +++-- .../server/http/policies_handler_test.go | 16 ++- .../server/http/posture_checks_handler.go | 4 +- .../http/posture_checks_handler_test.go | 12 +- management/server/http/routes_handler_test.go | 16 ++- .../server/http/setupkeys_handler_test.go | 18 +-- management/server/http/users_handler.go | 4 +- management/server/http/users_handler_test.go | 7 +- management/server/peer_test.go | 4 +- management/server/route_test.go | 2 +- management/server/user_test.go | 5 +- 20 files changed, 199 insertions(+), 160 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 03b5fa83e..303261bea 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { assert.Equal(t, account.Id, ev.TargetID) } -func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { +func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims type test struct { @@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") + if testCase.inputUpdateAttrs { err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") @@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) + accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount = acc + initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount @@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "only ALL group should exists") }) @@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} @@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) return } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } @@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - assert.NotNil(t, account.Settings) - assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) - assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") + + assert.NotNil(t, settings) + assert.Equal(t, settings.PeerLoginExpirationEnabled, true) + assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour) } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + // when we mark peer as connected, the peer login expiration routine should trigger err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") require.NoError(t, err, "unable to get account by ID") - assert.False(t, account.Settings.PeerLoginExpirationEnabled) - assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + assert.False(t, settings.PeerLoginExpirationEnabled) + assert.Equal(t, settings.PeerLoginExpiration, time.Hour) + + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) diff --git a/management/server/file_store.go b/management/server/file_store.go index 1b61b2a68..84b5547a9 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -962,6 +962,14 @@ func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (st return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") } -func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ string) (string, string, error) { - return "", "", status.Errorf(status.Internal, "GetAccountDomainAndCategory is not implemented") +func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return "", "", err + } + + return account.Domain, account.DomainCategory, nil } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 45c7679e5..cacb3d430 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -23,8 +23,11 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return account, admin, nil + GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return account.Id, admin.Id, nil + }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + return account.Settings, nil }, UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index 897ae63dc..8baea7b15 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil + GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 8bdd508bf..e525cf2ee 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { +func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { return &EventsHandler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { @@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil @@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) { accountID := "test_account" adminUser := server.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) - handler := initEventsTestData(accountID, adminUser, events...) + handler := initEventsTestData(accountID, events...) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 7f4d6dc7c..19c916dd2 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -11,9 +11,9 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return server.NewAdminUser(id), nil }, }, geolocationManager: geo, diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index d5ed07c9e..7f3c81f18 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/magiconair/properties/assert" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { +func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { @@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return nil }, GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - if groupID != "idofthegroup" { + groups := map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + } + + for _, group := range initGroups { + groups[group.ID] = group + } + + group, ok := groups[groupID] + if !ok { return nil, status.Errorf(status.NotFound, "not found") } - if groupID == "id-jwt-group" { - return &nbgroup.Group{ - ID: "id-jwt-group", - Name: "Default Group", - Issued: nbgroup.GroupIssuedJWT, - }, nil - } - return &nbgroup.Group{ - ID: "idofthegroup", - Name: "Group", - Issued: nbgroup.GroupIssuedAPI, - }, nil + + return group, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Peers: TestPeers, - Users: map[string]*server.User{ - user.Id: user, - }, - Groups: map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + if groupName == "All" { + return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + } + + return nil, fmt.Errorf("unknown group name") + }, + GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + return maps.Values(TestPeers), nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { @@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser, group) + p := initGroupTestData(group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 28b080571..98c2e402d 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -29,14 +28,6 @@ const ( testNSGroupAccountID = "test_id" ) -var testingNSAccount = &server.Account{ - Id: testNSGroupAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), - }, -} - var baseExistingNSGroup = &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: "super", @@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index b72f71468..c28228a50 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -77,8 +77,8 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testAccount, testAccount.Users[existingUserID], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { @@ -119,7 +119,7 @@ func initPATTestData() *PATHandler { return jwtclaims.AuthorizationClaims{ UserId: existingUserID, Domain: testDomain, - AccountId: testNSGroupAccountID, + AccountId: existingAccountID, } }), ), diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dae264fff..f933eee14 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -13,16 +13,15 @@ import ( "time" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() @@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { policy := &server.Policy{ ID: "policy", - AccountID: claims.AccountId, + AccountID: accountID, Name: "policy", Enabled: true, Rules: []*server.PolicyRule{ @@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { srvUser.IsServiceUser = true account := &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Peers: peersMap, Users: map[string]*server.User{ @@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { Groups: map[string]*nbgroup.Group{ "group1": { ID: "group1", - AccountID: claims.AccountId, + AccountID: accountID, Name: "group1", Issued: "api", Peers: maps.Keys(peersMap), @@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, } - return account, account.Users[claims.UserId], nil + return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { statuses := make(map[string]struct{}) @@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) { // hardcode this check for now as we only have two peers in this suite assert.Equal(t, len(respBody), 2) - assert.Equal(t, respBody[1].Connected, false) - got = respBody[0] + for _, peer := range respBody { + if peer.Id == testPeerID { + got = peer + } else { + assert.Equal(t, peer.Connected, false) + } + } + } else { got = &api.Peer{} err = json.Unmarshal(content, got) diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 06274fb07..228ebcbce 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + }, + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + user := server.NewAdminUser(userID) return &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Policies: []*server.Policy{ {ID: "id-existed"}, @@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Users: map[string]*server.User{ "test_user": user, }, - }, user, nil + }, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 0ab2b3a88..1d020e9bc 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -73,9 +73,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID) + _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) if err != nil { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) + util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 974edafde..02f0f0d83 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - PostureChecks: postureChecks, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, geolocationManager: &geolocation.Geolocation{}, diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 40075eb9d..2c367cac3 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler { if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) } + if peerID != "" { + if peerID == nonLinuxExistingPeerID { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + return &route.Route{ ID: existingRouteID, NetID: netID, @@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler { if r.Peer == notFoundPeerID { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) } + + if r.Peer == nonLinuxExistingPeerID { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + return nil }, DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { @@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + //return testingAccount, testingAccount.Users["test_user"], nil + return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index bfa0ec008..2d15287af 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" @@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - SetupKeys: map[string]*server.SetupKey{ - defaultKey.Key: defaultKey, - }, - Groups: map[string]*nbgroup.Group{ - "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index e36b11729..6e151a0da 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -49,7 +49,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) targetUserID := vars["userId"] - if len(userID) == 0 { + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } @@ -79,7 +79,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ - Id: userID, + Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, Blocked: req.IsBlocked, diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index a78ac3a4e..f3d989da1 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return usersTestAccount, usersTestAccount.Users[claims.UserId], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return usersTestAccount.Id, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return usersTestAccount.Users[id], nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4b2ec66c6..d329e04bc 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { Action: PolicyTrafficActionAccept, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return diff --git a/management/server/route_test.go b/management/server/route_test.go index 506bfb0a8..4533c6b7e 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/user_test.go b/management/server/user_test.go index 272060276..28284f517 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -796,7 +796,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") + assert.NoError(t, err) + + acc, err := am.Store.GetAccount(context.Background(), accID) assert.NoError(t, err) for _, id := range tc.expectedDeleted {