From 7124cf5c94e331393868c66fb786696790f70fee Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 18:35:37 +0300 Subject: [PATCH] Fix merge Signed-off-by: bcmmbaga --- management/server/account.go | 16 ++--- .../http/middleware/auth_middleware_test.go | 2 +- management/server/store/sql_store.go | 25 +++---- management/server/store/sql_store_test.go | 22 +++---- management/server/user.go | 66 +++++-------------- management/server/user_test.go | 8 +-- 6 files changed, 55 insertions(+), 84 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 5c23dba04..0d1f27fad 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1043,7 +1043,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - newUser := NewRegularUser(claims.UserId) + newUser := types.NewRegularUser(claims.UserId) newUser.AccountID = domainAccountID err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) if err != nil { @@ -1121,14 +1121,14 @@ func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, toke } // extractPATFromToken validates the token structure and retrieves associated User and PAT. -func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) { - if len(token) != PATLength { +func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) { + if len(token) != types.PATLength { return nil, nil, fmt.Errorf("token has incorrect length") } prefix := token[:len(types.PATPrefix)] if prefix != types.PATPrefix { - return nil, nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, fmt.Errorf("token has wrong prefix") } secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] @@ -1146,10 +1146,10 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - var user *User - var pat *PersonalAccessToken + var user *types.User + var pat *types.PersonalAccessToken - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) if err != nil { return err @@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting user peers: %w", err) } - updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) if err != nil { return fmt.Errorf("error modifying user peers in groups: %w", err) } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 7297e6ced..c1686ed44 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -36,7 +36,7 @@ var testAccount = &types.Account{ userID: { Id: userID, AccountID: accountID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ tokenID: { ID: tokenID, Name: "My first token", diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d5cee567f..4a847a253 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/management/server/util" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -567,8 +568,8 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, return nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { - var users []*User +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { + var users []*types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -899,7 +900,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { var createdBy string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). Select("created_by").First(&createdBy, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2063,8 +2064,8 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki } // GetPATByHashedToken returns a PersonalAccessToken by its hashed token. -func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) { - var pat PersonalAccessToken +func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -2078,8 +2079,8 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking } // GetPATByID retrieves a personal access token by its ID and user ID. -func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) { - var pat PersonalAccessToken +func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&pat, "id = ? AND user_id = ?", patID, userID) if err := result.Error; err != nil { @@ -2094,8 +2095,8 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, } // GetUserPATs retrieves personal access tokens for a user. -func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) { - var pats []*PersonalAccessToken +func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) { + var pats []*types.PersonalAccessToken result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err) @@ -2107,8 +2108,8 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength // MarkPATUsed marks a personal access token as used. func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error { - patCopy := PersonalAccessToken{ - LastUsed: time.Now().UTC(), + patCopy := types.PersonalAccessToken{ + LastUsed: util.ToPtr(time.Now().UTC()), } fieldsToUpdate := []string{"last_used"} @@ -2127,7 +2128,7 @@ func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength } // SavePAT saves a personal access token to the database. -func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { +func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 490586271..056c5b049 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2916,16 +2916,16 @@ func TestSqlStore_SaveUser(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - user := &User{ + user := &types.User{ Id: "user-id", AccountID: accountID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"groupA", "groupB"}, Blocked: false, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), CreatedAt: time.Now().UTC().Add(-time.Hour), - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } err = store.SaveUser(context.Background(), LockingStrengthUpdate, user) require.NoError(t, err) @@ -2936,7 +2936,7 @@ func TestSqlStore_SaveUser(t *testing.T) { require.Equal(t, user.AccountID, saveUser.AccountID) require.Equal(t, user.Role, saveUser.Role) require.Equal(t, user.AutoGroups, saveUser.AutoGroups) - require.WithinDurationf(t, user.LastLogin, saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") + require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") require.Equal(t, user.Issued, saveUser.Issued) require.Equal(t, user.Blocked, saveUser.Blocked) @@ -2954,7 +2954,7 @@ func TestSqlStore_SaveUsers(t *testing.T) { require.NoError(t, err) require.Len(t, accountUsers, 2) - users := []*User{ + users := []*types.User{ { Id: "user-1", AccountID: accountID, @@ -3087,15 +3087,15 @@ func TestSqlStore_SavePAT(t *testing.T) { userID := "edafee4e-63fb-11ec-90d6-0242ac120003" - pat := &PersonalAccessToken{ + pat := &types.PersonalAccessToken{ ID: "pat-id", UserID: userID, Name: "token", HashedToken: "SoMeHaShEdToKeN", - ExpirationDate: time.Now().UTC().Add(12 * time.Hour), + ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)), CreatedBy: userID, CreatedAt: time.Now().UTC().Add(time.Hour), - LastUsed: time.Now().UTC().Add(-15 * time.Minute), + LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)), } err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat) require.NoError(t, err) @@ -3106,9 +3106,9 @@ func TestSqlStore_SavePAT(t *testing.T) { require.Equal(t, pat.UserID, savePAT.UserID) require.Equal(t, pat.HashedToken, savePAT.HashedToken) require.Equal(t, pat.CreatedBy, savePAT.CreatedBy) - require.WithinDurationf(t, pat.ExpirationDate, savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal") + require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal") require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") - require.WithinDurationf(t, pat.LastUsed, savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal") + require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal") } func TestSqlStore_DeletePAT(t *testing.T) { diff --git a/management/server/user.go b/management/server/user.go index 22cd785eb..1e0e28122 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -139,7 +139,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u newUser := &types.User{ Id: idpUser.ID, AccountID: accountID, - Role: StrRoleToUserRole(invite.Role), + Role: types.StrRoleToUserRole(invite.Role), AutoGroups: invite.AutoGroups, Issued: invite.Issued, IntegrationReference: invite.IntegrationReference, @@ -489,20 +489,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, var updateAccountPeers bool var peersToExpire []*nbpeer.Peer var addUserEvents []func() - var usersToSave = make([]*User, 0, len(updates)) - var updatedUsersInfo = make([]*UserInfo, 0, len(updates)) + var usersToSave = make([]*types.User, 0, len(updates)) + var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates)) groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, fmt.Errorf("error getting account groups: %w", err) } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, update := range updates { if update == nil { return status.Errorf(status.InvalidArgument, "provided user update is nil") @@ -550,14 +550,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return updatedUsersInfo, nil } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*types.Group, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -586,36 +586,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, gr return eventsToStore } -func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { - var eventsToStore []func() - if newUser.AutoGroups != nil { - removedGroups := utildifference(oldUser.AutoGroups, newUser.AutoGroups) - addedGroups := utildifference(newUser.AutoGroups, oldUser.AutoGroups) - for _, g := range removedGroups { - group, ok := groupsMap[g] - if ok { - eventsToStore = append(eventsToStore, func() { - meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName} - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) - }) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID) - } - } - for _, g := range addedGroups { - group, ok := groupsMap[g] - if ok { - eventsToStore = append(eventsToStore, func() { - meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName} - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta) - }) - } - } - } - return eventsToStore -} - -func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*types.Group, +func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction store.Store, groupsMap map[string]*types.Group, initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { if update == nil { @@ -658,8 +629,8 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti } if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups := util.difference(oldUser.AutoGroups, update.AutoGroups) - updatedGroups, err := am.updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) + removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) + updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) if err != nil { return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) } @@ -670,13 +641,13 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, groupsMap, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. -func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *User, addIfNotExists bool) (*User, error) { +func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) { existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { @@ -690,8 +661,8 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update * return existingUser, nil } -func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *User) (bool, error) { - if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { +func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) { + if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { newInitiatorUser := initiatorUser.Copy() newInitiatorUser.Role = types.UserRoleAdmin @@ -713,6 +684,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction st } if !isNil(am.idpManager) && !user.IsServiceUser { + // TODO: Run lookupUserInCache with transaction userData, err := am.lookupUserInCache(ctx, user.Id, accountID) if err != nil { return nil, err @@ -818,7 +790,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun users := make(map[string]userLoggedInOnce, len(accountUsers)) usersFromIntegration := make([]*idp.UserData, 0) for _, user := range accountUsers { - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { @@ -1059,7 +1031,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var addPeerRemovedEvents []func() var updateAccountPeers bool - var targetUser *User + var targetUser *types.User err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) @@ -1100,9 +1072,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { - +func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return } diff --git a/management/server/user_test.go b/management/server/user_test.go index e9889e56b..5c4b1e2cb 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -45,7 +45,7 @@ const ( ) func TestUser_CreatePAT_ForSameUser(t *testing.T) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) if err != nil { t.Fatalf("Error when creating store: %s", err) } @@ -53,13 +53,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err = store.SaveAccount(context.Background(), account) + err = s.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } am := DefaultAccountManager{ - Store: store, + Store: s, eventStore: &activity.InMemoryEventStore{}, } @@ -81,7 +81,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.ID, tokenID) - user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, tokenID) + user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) if err != nil { t.Fatalf("Error when getting user by token ID: %s", err) }