Fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2025-01-15 18:35:37 +03:00
parent a72a331128
commit 7124cf5c94
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 55 additions and 84 deletions

View File

@ -1043,7 +1043,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount() defer unlockAccount()
newUser := NewRegularUser(claims.UserId) newUser := types.NewRegularUser(claims.UserId)
newUser.AccountID = domainAccountID newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
if err != nil { if err != nil {
@ -1121,14 +1121,14 @@ func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, toke
} }
// extractPATFromToken validates the token structure and retrieves associated User and PAT. // extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) { func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
if len(token) != PATLength { if len(token) != types.PATLength {
return nil, nil, fmt.Errorf("token has incorrect length") return nil, nil, fmt.Errorf("token has incorrect length")
} }
prefix := token[:len(types.PATPrefix)] prefix := token[:len(types.PATPrefix)]
if prefix != types.PATPrefix { if prefix != types.PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix") return nil, nil, fmt.Errorf("token has wrong prefix")
} }
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
@ -1146,10 +1146,10 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
var user *User var user *types.User
var pat *PersonalAccessToken var pat *types.PersonalAccessToken
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
if err != nil { if err != nil {
return err return err
@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user peers: %w", err) return fmt.Errorf("error getting user peers: %w", err)
} }
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil { if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err) return fmt.Errorf("error modifying user peers in groups: %w", err)
} }

View File

@ -36,7 +36,7 @@ var testAccount = &types.Account{
userID: { userID: {
Id: userID, Id: userID,
AccountID: accountID, AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{ PATs: map[string]*types.PersonalAccessToken{
tokenID: { tokenID: {
ID: tokenID, ID: tokenID,
Name: "My first token", Name: "My first token",

View File

@ -15,6 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/server/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
@ -567,8 +568,8 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
return nil return nil
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
var users []*User var users []*types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -899,7 +900,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
var createdBy string var createdBy string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID) Select("created_by").First(&createdBy, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -2063,8 +2064,8 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
} }
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token. // GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) { func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
var pat PersonalAccessToken var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -2078,8 +2079,8 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
} }
// GetPATByID retrieves a personal access token by its ID and user ID. // GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) { func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
var pat PersonalAccessToken var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "id = ? AND user_id = ?", patID, userID) First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@ -2094,8 +2095,8 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
} }
// GetUserPATs retrieves personal access tokens for a user. // GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) { func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
var pats []*PersonalAccessToken var pats []*types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err) log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
@ -2107,8 +2108,8 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
// MarkPATUsed marks a personal access token as used. // MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error { func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := PersonalAccessToken{ patCopy := types.PersonalAccessToken{
LastUsed: time.Now().UTC(), LastUsed: util.ToPtr(time.Now().UTC()),
} }
fieldsToUpdate := []string{"last_used"} fieldsToUpdate := []string{"last_used"}
@ -2127,7 +2128,7 @@ func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength
} }
// SavePAT saves a personal access token to the database. // SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err) log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)

View File

@ -2916,16 +2916,16 @@ func TestSqlStore_SaveUser(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
user := &User{ user := &types.User{
Id: "user-id", Id: "user-id",
AccountID: accountID, AccountID: accountID,
Role: UserRoleAdmin, Role: types.UserRoleAdmin,
IsServiceUser: false, IsServiceUser: false,
AutoGroups: []string{"groupA", "groupB"}, AutoGroups: []string{"groupA", "groupB"},
Blocked: false, Blocked: false,
LastLogin: time.Now().UTC(), LastLogin: util.ToPtr(time.Now().UTC()),
CreatedAt: time.Now().UTC().Add(-time.Hour), CreatedAt: time.Now().UTC().Add(-time.Hour),
Issued: UserIssuedIntegration, Issued: types.UserIssuedIntegration,
} }
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user) err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
require.NoError(t, err) require.NoError(t, err)
@ -2936,7 +2936,7 @@ func TestSqlStore_SaveUser(t *testing.T) {
require.Equal(t, user.AccountID, saveUser.AccountID) require.Equal(t, user.AccountID, saveUser.AccountID)
require.Equal(t, user.Role, saveUser.Role) require.Equal(t, user.Role, saveUser.Role)
require.Equal(t, user.AutoGroups, saveUser.AutoGroups) require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
require.WithinDurationf(t, user.LastLogin, saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.Equal(t, user.Issued, saveUser.Issued) require.Equal(t, user.Issued, saveUser.Issued)
require.Equal(t, user.Blocked, saveUser.Blocked) require.Equal(t, user.Blocked, saveUser.Blocked)
@ -2954,7 +2954,7 @@ func TestSqlStore_SaveUsers(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, accountUsers, 2) require.Len(t, accountUsers, 2)
users := []*User{ users := []*types.User{
{ {
Id: "user-1", Id: "user-1",
AccountID: accountID, AccountID: accountID,
@ -3087,15 +3087,15 @@ func TestSqlStore_SavePAT(t *testing.T) {
userID := "edafee4e-63fb-11ec-90d6-0242ac120003" userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
pat := &PersonalAccessToken{ pat := &types.PersonalAccessToken{
ID: "pat-id", ID: "pat-id",
UserID: userID, UserID: userID,
Name: "token", Name: "token",
HashedToken: "SoMeHaShEdToKeN", HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: time.Now().UTC().Add(12 * time.Hour), ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
CreatedBy: userID, CreatedBy: userID,
CreatedAt: time.Now().UTC().Add(time.Hour), CreatedAt: time.Now().UTC().Add(time.Hour),
LastUsed: time.Now().UTC().Add(-15 * time.Minute), LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
} }
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat) err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
require.NoError(t, err) require.NoError(t, err)
@ -3106,9 +3106,9 @@ func TestSqlStore_SavePAT(t *testing.T) {
require.Equal(t, pat.UserID, savePAT.UserID) require.Equal(t, pat.UserID, savePAT.UserID)
require.Equal(t, pat.HashedToken, savePAT.HashedToken) require.Equal(t, pat.HashedToken, savePAT.HashedToken)
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy) require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
require.WithinDurationf(t, pat.ExpirationDate, savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal") require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.WithinDurationf(t, pat.LastUsed, savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal") require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
} }
func TestSqlStore_DeletePAT(t *testing.T) { func TestSqlStore_DeletePAT(t *testing.T) {

View File

@ -139,7 +139,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
newUser := &types.User{ newUser := &types.User{
Id: idpUser.ID, Id: idpUser.ID,
AccountID: accountID, AccountID: accountID,
Role: StrRoleToUserRole(invite.Role), Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups, AutoGroups: invite.AutoGroups,
Issued: invite.Issued, Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference, IntegrationReference: invite.IntegrationReference,
@ -489,20 +489,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var updateAccountPeers bool var updateAccountPeers bool
var peersToExpire []*nbpeer.Peer var peersToExpire []*nbpeer.Peer
var addUserEvents []func() var addUserEvents []func()
var usersToSave = make([]*User, 0, len(updates)) var usersToSave = make([]*types.User, 0, len(updates))
var updatedUsersInfo = make([]*UserInfo, 0, len(updates)) var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
groupsMap := make(map[string]*nbgroup.Group, len(groups)) groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups { for _, group := range groups {
groupsMap[group.ID] = group groupsMap[group.ID] = group
} }
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, update := range updates { for _, update := range updates {
if update == nil { if update == nil {
return status.Errorf(status.InvalidArgument, "provided user update is nil") return status.Errorf(status.InvalidArgument, "provided user update is nil")
@ -550,14 +550,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err) return nil, fmt.Errorf("failed to increment network serial: %w", err)
} }
am.updateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
return updatedUsersInfo, nil return updatedUsersInfo, nil
} }
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*types.Group, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
var eventsToStore []func() var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() { if oldUser.IsBlocked() != newUser.IsBlocked() {
@ -586,36 +586,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, gr
return eventsToStore return eventsToStore
} }
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction store.Store, groupsMap map[string]*types.Group,
var eventsToStore []func()
if newUser.AutoGroups != nil {
removedGroups := utildifference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := utildifference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group, ok := groupsMap[g]
if ok {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
}
}
for _, g := range addedGroups {
group, ok := groupsMap[g]
if ok {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
})
}
}
}
return eventsToStore
}
func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*types.Group,
initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) { initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
if update == nil { if update == nil {
@ -658,8 +629,8 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti
} }
if update.AutoGroups != nil && settings.GroupsPropagationEnabled { if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups := util.difference(oldUser.AutoGroups, update.AutoGroups) removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
if err != nil { if err != nil {
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
} }
@ -670,13 +641,13 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti
} }
updateAccountPeers := len(userPeers) > 0 updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, groupsMap, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole) userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
} }
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *User, addIfNotExists bool) (*User, error) { func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
@ -690,8 +661,8 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *
return existingUser, nil return existingUser, nil
} }
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *User) (bool, error) { func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
newInitiatorUser := initiatorUser.Copy() newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = types.UserRoleAdmin newInitiatorUser.Role = types.UserRoleAdmin
@ -713,6 +684,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction st
} }
if !isNil(am.idpManager) && !user.IsServiceUser { if !isNil(am.idpManager) && !user.IsServiceUser {
// TODO: Run lookupUserInCache with transaction
userData, err := am.lookupUserInCache(ctx, user.Id, accountID) userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -818,7 +790,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
users := make(map[string]userLoggedInOnce, len(accountUsers)) users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0) usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range accountUsers { for _, user := range accountUsers {
if user.Issued == UserIssuedIntegration { if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id) key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key) info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil { if err != nil {
@ -1059,7 +1031,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var addPeerRemovedEvents []func() var addPeerRemovedEvents []func()
var updateAccountPeers bool var updateAccountPeers bool
var targetUser *User var targetUser *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
@ -1100,9 +1072,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
} }
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return return
} }

View File

@ -45,7 +45,7 @@ const (
) )
func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil { if err != nil {
t.Fatalf("Error when creating store: %s", err) t.Fatalf("Error when creating store: %s", err)
} }
@ -53,13 +53,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err = store.SaveAccount(context.Background(), account) err = s.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
t.Fatalf("Error when saving account: %s", err) t.Fatalf("Error when saving account: %s", err)
} }
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: s,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
} }
@ -81,7 +81,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, tokenID) user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }