mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-31 18:39:31 +01:00
Fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
a72a331128
commit
7124cf5c94
@ -1043,7 +1043,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
|||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
|
|
||||||
newUser := NewRegularUser(claims.UserId)
|
newUser := types.NewRegularUser(claims.UserId)
|
||||||
newUser.AccountID = domainAccountID
|
newUser.AccountID = domainAccountID
|
||||||
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1121,14 +1121,14 @@ func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, toke
|
|||||||
}
|
}
|
||||||
|
|
||||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||||
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||||
if len(token) != PATLength {
|
if len(token) != types.PATLength {
|
||||||
return nil, nil, fmt.Errorf("token has incorrect length")
|
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := token[:len(types.PATPrefix)]
|
prefix := token[:len(types.PATPrefix)]
|
||||||
if prefix != types.PATPrefix {
|
if prefix != types.PATPrefix {
|
||||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
return nil, nil, fmt.Errorf("token has wrong prefix")
|
||||||
}
|
}
|
||||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||||
@ -1146,10 +1146,10 @@ func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token
|
|||||||
hashedToken := sha256.Sum256([]byte(token))
|
hashedToken := sha256.Sum256([]byte(token))
|
||||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||||
|
|
||||||
var user *User
|
var user *types.User
|
||||||
var pat *PersonalAccessToken
|
var pat *types.PersonalAccessToken
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error getting user peers: %w", err)
|
return fmt.Errorf("error getting user peers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ var testAccount = &types.Account{
|
|||||||
userID: {
|
userID: {
|
||||||
Id: userID,
|
Id: userID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
PATs: map[string]*server.PersonalAccessToken{
|
PATs: map[string]*types.PersonalAccessToken{
|
||||||
tokenID: {
|
tokenID: {
|
||||||
ID: tokenID,
|
ID: tokenID,
|
||||||
Name: "My first token",
|
Name: "My first token",
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
@ -567,8 +568,8 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||||
var users []*User
|
var users []*types.User
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
@ -899,7 +900,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||||
var createdBy string
|
var createdBy string
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||||
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
@ -2063,8 +2064,8 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||||
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) {
|
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
||||||
var pat PersonalAccessToken
|
var pat types.PersonalAccessToken
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
@ -2078,8 +2079,8 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
|
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
||||||
var pat PersonalAccessToken
|
var pat types.PersonalAccessToken
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
@ -2094,8 +2095,8 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserPATs retrieves personal access tokens for a user.
|
// GetUserPATs retrieves personal access tokens for a user.
|
||||||
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) {
|
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
||||||
var pats []*PersonalAccessToken
|
var pats []*types.PersonalAccessToken
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
||||||
@ -2107,8 +2108,8 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
|
|||||||
|
|
||||||
// MarkPATUsed marks a personal access token as used.
|
// MarkPATUsed marks a personal access token as used.
|
||||||
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
||||||
patCopy := PersonalAccessToken{
|
patCopy := types.PersonalAccessToken{
|
||||||
LastUsed: time.Now().UTC(),
|
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldsToUpdate := []string{"last_used"}
|
fieldsToUpdate := []string{"last_used"}
|
||||||
@ -2127,7 +2128,7 @@ func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePAT saves a personal access token to the database.
|
// SavePAT saves a personal access token to the database.
|
||||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
|
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
||||||
|
@ -2916,16 +2916,16 @@ func TestSqlStore_SaveUser(t *testing.T) {
|
|||||||
|
|
||||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
user := &User{
|
user := &types.User{
|
||||||
Id: "user-id",
|
Id: "user-id",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Role: UserRoleAdmin,
|
Role: types.UserRoleAdmin,
|
||||||
IsServiceUser: false,
|
IsServiceUser: false,
|
||||||
AutoGroups: []string{"groupA", "groupB"},
|
AutoGroups: []string{"groupA", "groupB"},
|
||||||
Blocked: false,
|
Blocked: false,
|
||||||
LastLogin: time.Now().UTC(),
|
LastLogin: util.ToPtr(time.Now().UTC()),
|
||||||
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
||||||
Issued: UserIssuedIntegration,
|
Issued: types.UserIssuedIntegration,
|
||||||
}
|
}
|
||||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -2936,7 +2936,7 @@ func TestSqlStore_SaveUser(t *testing.T) {
|
|||||||
require.Equal(t, user.AccountID, saveUser.AccountID)
|
require.Equal(t, user.AccountID, saveUser.AccountID)
|
||||||
require.Equal(t, user.Role, saveUser.Role)
|
require.Equal(t, user.Role, saveUser.Role)
|
||||||
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
|
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
|
||||||
require.WithinDurationf(t, user.LastLogin, saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
|
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
|
||||||
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||||
require.Equal(t, user.Issued, saveUser.Issued)
|
require.Equal(t, user.Issued, saveUser.Issued)
|
||||||
require.Equal(t, user.Blocked, saveUser.Blocked)
|
require.Equal(t, user.Blocked, saveUser.Blocked)
|
||||||
@ -2954,7 +2954,7 @@ func TestSqlStore_SaveUsers(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, accountUsers, 2)
|
require.Len(t, accountUsers, 2)
|
||||||
|
|
||||||
users := []*User{
|
users := []*types.User{
|
||||||
{
|
{
|
||||||
Id: "user-1",
|
Id: "user-1",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
@ -3087,15 +3087,15 @@ func TestSqlStore_SavePAT(t *testing.T) {
|
|||||||
|
|
||||||
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
pat := &PersonalAccessToken{
|
pat := &types.PersonalAccessToken{
|
||||||
ID: "pat-id",
|
ID: "pat-id",
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Name: "token",
|
Name: "token",
|
||||||
HashedToken: "SoMeHaShEdToKeN",
|
HashedToken: "SoMeHaShEdToKeN",
|
||||||
ExpirationDate: time.Now().UTC().Add(12 * time.Hour),
|
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
|
||||||
CreatedBy: userID,
|
CreatedBy: userID,
|
||||||
CreatedAt: time.Now().UTC().Add(time.Hour),
|
CreatedAt: time.Now().UTC().Add(time.Hour),
|
||||||
LastUsed: time.Now().UTC().Add(-15 * time.Minute),
|
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
|
||||||
}
|
}
|
||||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
|
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -3106,9 +3106,9 @@ func TestSqlStore_SavePAT(t *testing.T) {
|
|||||||
require.Equal(t, pat.UserID, savePAT.UserID)
|
require.Equal(t, pat.UserID, savePAT.UserID)
|
||||||
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
|
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
|
||||||
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
|
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
|
||||||
require.WithinDurationf(t, pat.ExpirationDate, savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
|
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
|
||||||
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||||
require.WithinDurationf(t, pat.LastUsed, savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
|
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlStore_DeletePAT(t *testing.T) {
|
func TestSqlStore_DeletePAT(t *testing.T) {
|
||||||
|
@ -139,7 +139,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
newUser := &types.User{
|
newUser := &types.User{
|
||||||
Id: idpUser.ID,
|
Id: idpUser.ID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Role: StrRoleToUserRole(invite.Role),
|
Role: types.StrRoleToUserRole(invite.Role),
|
||||||
AutoGroups: invite.AutoGroups,
|
AutoGroups: invite.AutoGroups,
|
||||||
Issued: invite.Issued,
|
Issued: invite.Issued,
|
||||||
IntegrationReference: invite.IntegrationReference,
|
IntegrationReference: invite.IntegrationReference,
|
||||||
@ -489,20 +489,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var peersToExpire []*nbpeer.Peer
|
var peersToExpire []*nbpeer.Peer
|
||||||
var addUserEvents []func()
|
var addUserEvents []func()
|
||||||
var usersToSave = make([]*User, 0, len(updates))
|
var usersToSave = make([]*types.User, 0, len(updates))
|
||||||
var updatedUsersInfo = make([]*UserInfo, 0, len(updates))
|
var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates))
|
||||||
|
|
||||||
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
groupsMap := make(map[string]*nbgroup.Group, len(groups))
|
groupsMap := make(map[string]*types.Group, len(groups))
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
groupsMap[group.ID] = group
|
groupsMap[group.ID] = group
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
for _, update := range updates {
|
for _, update := range updates {
|
||||||
if update == nil {
|
if update == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "provided user update is nil")
|
return status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||||
@ -550,14 +550,14 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
|||||||
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||||
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
return nil, fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.UpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return updatedUsersInfo, nil
|
return updatedUsersInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*types.Group, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
|
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
|
||||||
var eventsToStore []func()
|
var eventsToStore []func()
|
||||||
|
|
||||||
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
||||||
@ -586,36 +586,7 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, gr
|
|||||||
return eventsToStore
|
return eventsToStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
|
func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction store.Store, groupsMap map[string]*types.Group,
|
||||||
var eventsToStore []func()
|
|
||||||
if newUser.AutoGroups != nil {
|
|
||||||
removedGroups := utildifference(oldUser.AutoGroups, newUser.AutoGroups)
|
|
||||||
addedGroups := utildifference(newUser.AutoGroups, oldUser.AutoGroups)
|
|
||||||
for _, g := range removedGroups {
|
|
||||||
group, ok := groupsMap[g]
|
|
||||||
if ok {
|
|
||||||
eventsToStore = append(eventsToStore, func() {
|
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, g := range addedGroups {
|
|
||||||
group, ok := groupsMap[g]
|
|
||||||
if ok {
|
|
||||||
eventsToStore = append(eventsToStore, func() {
|
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}
|
|
||||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return eventsToStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*types.Group,
|
|
||||||
initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
initiatorUser, update *types.User, addIfNotExists bool, settings *types.Settings) (bool, *types.User, []*nbpeer.Peer, []func(), error) {
|
||||||
|
|
||||||
if update == nil {
|
if update == nil {
|
||||||
@ -658,8 +629,8 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti
|
|||||||
}
|
}
|
||||||
|
|
||||||
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
||||||
removedGroups := util.difference(oldUser.AutoGroups, update.AutoGroups)
|
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
||||||
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
|
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
|
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||||
}
|
}
|
||||||
@ -670,13 +641,13 @@ func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transacti
|
|||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers := len(userPeers) > 0
|
updateAccountPeers := len(userPeers) > 0
|
||||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, groupsMap, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
|
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole)
|
||||||
|
|
||||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
||||||
func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *User, addIfNotExists bool) (*User, error) {
|
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, update *types.User, addIfNotExists bool) (*types.User, error) {
|
||||||
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
|
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
@ -690,8 +661,8 @@ func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *
|
|||||||
return existingUser, nil
|
return existingUser, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *User) (bool, error) {
|
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
|
||||||
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
|
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
|
||||||
newInitiatorUser := initiatorUser.Copy()
|
newInitiatorUser := initiatorUser.Copy()
|
||||||
newInitiatorUser.Role = types.UserRoleAdmin
|
newInitiatorUser.Role = types.UserRoleAdmin
|
||||||
|
|
||||||
@ -713,6 +684,7 @@ func (am *DefaultAccountManager) getUserInfo(ctx context.Context, transaction st
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !isNil(am.idpManager) && !user.IsServiceUser {
|
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||||
|
// TODO: Run lookupUserInCache with transaction
|
||||||
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
|
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -818,7 +790,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||||
usersFromIntegration := make([]*idp.UserData, 0)
|
usersFromIntegration := make([]*idp.UserData, 0)
|
||||||
for _, user := range accountUsers {
|
for _, user := range accountUsers {
|
||||||
if user.Issued == UserIssuedIntegration {
|
if user.Issued == types.UserIssuedIntegration {
|
||||||
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
||||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1059,7 +1031,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
|||||||
|
|
||||||
var addPeerRemovedEvents []func()
|
var addPeerRemovedEvents []func()
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var targetUser *User
|
var targetUser *types.User
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
|
||||||
@ -1100,9 +1072,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||||
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd,
|
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
|
||||||
groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
|
|
||||||
|
|
||||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when creating store: %s", err)
|
t.Fatalf("Error when creating store: %s", err)
|
||||||
}
|
}
|
||||||
@ -53,13 +53,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
|||||||
|
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||||
|
|
||||||
err = store.SaveAccount(context.Background(), account)
|
err = s.SaveAccount(context.Background(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when saving account: %s", err)
|
t.Fatalf("Error when saving account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
am := DefaultAccountManager{
|
am := DefaultAccountManager{
|
||||||
Store: store,
|
Store: s,
|
||||||
eventStore: &activity.InMemoryEventStore{},
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, pat.ID, tokenID)
|
assert.Equal(t, pat.ID, tokenID)
|
||||||
|
|
||||||
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, tokenID)
|
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user