From 5bed6777d568f1e0e96a4c3dbd290e175502eb81 Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Fri, 23 May 2025 14:42:42 +0100 Subject: [PATCH] [management] force account id on save groups update (#3850) --- management/server/account.go | 4 ++-- management/server/group.go | 2 +- management/server/posture_checks_test.go | 2 +- management/server/store/sql_store.go | 12 ++++++++++-- management/server/store/sql_store_test.go | 6 +++--- management/server/store/store.go | 2 +- management/server/user.go | 2 +- 7 files changed, 19 insertions(+), 11 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 6dc449c1e..033ec5fa1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1248,7 +1248,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, newGroupsToCreate); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } @@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil { return fmt.Errorf("error saving groups: %w", err) } diff --git a/management/server/group.go b/management/server/group.go index 87d649228..c26a0cfc1 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -116,7 +116,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave) + return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) }) if err != nil { return err diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 232955f7d..8bd2fab66 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -455,7 +455,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) + err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB}) require.NoError(t, err, "failed to save groups") postureCheckA := &posture.Checks{ diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index eb194ca9b..6c3104ef0 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -448,12 +448,20 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { if len(groups) == 0 { return nil } - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&groups) + result := s.db. + Clauses( + clause.Locking{Strength: string(lockStrength)}, + clause.OnConflict{ + Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, + UpdateAll: true, + }, + ). + Create(&groups) if result.Error != nil { return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 8e99b34e1..2c1f5f8e6 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -1324,11 +1324,11 @@ func TestSqlStore_SaveGroups(t *testing.T) { Peers: []string{"peer3", "peer4"}, }, } - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups) require.NoError(t, err) groups[1].Peers = []string{} - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups) require.NoError(t, err) group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID) @@ -3240,7 +3240,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { }) } - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groupsToSave) + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave) require.NoError(t, err) accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) diff --git a/management/server/store/store.go b/management/server/store/store.go index 3d529ceb5..b3c2fceff 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -98,7 +98,7 @@ type Store interface { GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error diff --git a/management/server/user.go b/management/server/user.go index 44ad3b68f..2c762a8eb 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -676,7 +676,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, update.AccountID, updatedGroups); err != nil { return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) } }