Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2025-01-15 18:35:37 +03:00
parent a72a331128
commit 7124cf5c94
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 55 additions and 84 deletions

View File

@ -1043,7 +1043,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
newUser := NewRegularUser(claims.UserId)
newUser := types.NewRegularUser(claims.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
if err != nil {
@ -1121,14 +1121,14 @@ func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, toke
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
if len(token) != PATLength {
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
if len(token) != types.PATLength {
return nil, nil, fmt.Errorf("token has incorrect length")
}
prefix := token[:len(types.PATPrefix)]
if prefix != types.PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
return nil, nil, fmt.Errorf("token has wrong prefix")
}
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
@ -1146,10 +1146,10 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
var user *User
var pat *PersonalAccessToken
var user *types.User
var pat *types.PersonalAccessToken
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
if err != nil {
return err
@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}

View File

@ -36,7 +36,7 @@ var testAccount = &types.Account{
userID: {
Id: userID,
AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
tokenID: {
ID: tokenID,
Name: "My first token",

View File

@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/util"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@ -567,8 +568,8 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
return nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
var users []*types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -899,7 +900,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
var createdBy string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -2063,8 +2064,8 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
}
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -2078,8 +2079,8 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
@ -2094,8 +2095,8 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
}
// GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) {
var pats []*PersonalAccessToken
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
var pats []*types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
@ -2107,8 +2108,8 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
// MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := PersonalAccessToken{
LastUsed: time.Now().UTC(),
patCopy := types.PersonalAccessToken{
LastUsed: util.ToPtr(time.Now().UTC()),
}
fieldsToUpdate := []string{"last_used"}
@ -2127,7 +2128,7 @@ func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)

View File

@ -2916,16 +2916,16 @@ func TestSqlStore_SaveUser(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
user := &User{
user := &types.User{
Id: "user-id",
AccountID: accountID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
IsServiceUser: false,
AutoGroups: []string{"groupA", "groupB"},
Blocked: false,
LastLogin: time.Now().UTC(),
LastLogin: util.ToPtr(time.Now().UTC()),
CreatedAt: time.Now().UTC().Add(-time.Hour),
Issued: UserIssuedIntegration,
Issued: types.UserIssuedIntegration,
}
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
require.NoError(t, err)
@ -2936,7 +2936,7 @@ func TestSqlStore_SaveUser(t *testing.T) {
require.Equal(t, user.AccountID, saveUser.AccountID)
require.Equal(t, user.Role, saveUser.Role)
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
require.WithinDurationf(t, user.LastLogin, saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.Equal(t, user.Issued, saveUser.Issued)
require.Equal(t, user.Blocked, saveUser.Blocked)
@ -2954,7 +2954,7 @@ func TestSqlStore_SaveUsers(t *testing.T) {
require.NoError(t, err)
require.Len(t, accountUsers, 2)
users := []*User{
users := []*types.User{
{
Id: "user-1",
AccountID: accountID,
@ -3087,15 +3087,15 @@ func TestSqlStore_SavePAT(t *testing.T) {
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
pat := &PersonalAccessToken{
pat := &types.PersonalAccessToken{
ID: "pat-id",
UserID: userID,
Name: "token",
HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: time.Now().UTC().Add(12 * time.Hour),
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
CreatedBy: userID,
CreatedAt: time.Now().UTC().Add(time.Hour),
LastUsed: time.Now().UTC().Add(-15 * time.Minute),
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
}
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
require.NoError(t, err)
@ -3106,9 +3106,9 @@ func TestSqlStore_SavePAT(t *testing.T) {
require.Equal(t, pat.UserID, savePAT.UserID)
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
require.WithinDurationf(t, pat.ExpirationDate, savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.WithinDurationf(t, pat.LastUsed, savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
}
func TestSqlStore_DeletePAT(t *testing.T) {

View File

@ -139,7 +139,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
newUser := &types.User{
Id: idpUser.ID,
AccountID: accountID,
Role: StrRoleToUserRole(invite.Role),
Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups,
Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference,
@ -489,20 +489,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var updateAccountPeers bool
var peersToExpire []*nbpeer.Peer
var addUserEvents []func()
var usersToSave = make([]*User, 0, len(updates))
var updatedUsersInfo = make([]*UserInfo, 0, len(updates))
var usersToSave = make([]*types.User, 0, len(updates))
var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, update := range updates {
if update == nil {
return status.Errorf(status.InvalidArgument, "provided user update is nil")
@ -550,14 +550,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err)
}
am.updateAccountPeers(ctx, accountID)
am.UpdateAccountPeers(ctx, accountID)
}
return updatedUsersInfo, nil
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*types.Group, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
@ -586,36 +586,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, gr
return eventsToStore
}
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
if newUser.AutoGroups != nil {
removedGroups := utildifference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := utildifference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group, ok := groupsMap[g]
if ok {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
}
}
for _, g := range addedGroups {
group, ok := groupsMap[g]
if ok {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
})
}
}
}
return eventsToStore
}
func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*types.Group,
func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction store.Store, groupsMap map[string]*types.Group,
initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
if update == nil {
@ -658,8 +629,8 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti
}
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups := util.difference(oldUser.AutoGroups, update.AutoGroups)
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
if err != nil {
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
}
@ -670,13 +641,13 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti
}
updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, groupsMap, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *User, addIfNotExists bool) (*User, error) {
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
@ -690,8 +661,8 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *
return existingUser, nil
}
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *User) (bool, error) {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = types.UserRoleAdmin
@ -713,6 +684,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction st
}
if !isNil(am.idpManager) && !user.IsServiceUser {
// TODO: Run lookupUserInCache with transaction
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
if err != nil {
return nil, err
@ -818,7 +790,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range accountUsers {
if user.Issued == UserIssuedIntegration {
if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
@ -1059,7 +1031,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var addPeerRemovedEvents []func()
var updateAccountPeers bool
var targetUser *User
var targetUser *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
@ -1100,9 +1072,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}

View File

@ -45,7 +45,7 @@ const (
)
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
@ -53,13 +53,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err = store.SaveAccount(context.Background(), account)
err = s.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
Store: s,
eventStore: &activity.InMemoryEventStore{},
}
@ -81,7 +81,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}