[management] force account id on save groups update (#3850)

This commit is contained in:
Pedro Maia Costa 2025-05-23 14:42:42 +01:00 committed by GitHub
parent a0482ebc7b
commit 5bed6777d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 19 additions and 11 deletions

View File

@ -1248,7 +1248,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, newGroupsToCreate); err != nil { if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return fmt.Errorf("error modifying user peers in groups: %w", err) return fmt.Errorf("error modifying user peers in groups: %w", err)
} }
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }

View File

@ -116,7 +116,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
return err return err
} }
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave) return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err

View File

@ -455,7 +455,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
AccountID: account.Id, AccountID: account.Id,
Peers: []string{}, Peers: []string{},
} }
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups") require.NoError(t, err, "failed to save groups")
postureCheckA := &posture.Checks{ postureCheckA := &posture.Checks{

View File

@ -448,12 +448,20 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
} }
// SaveGroups saves the given list of groups to the database. // SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
if len(groups) == 0 { if len(groups) == 0 {
return nil return nil
} }
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&groups) result := s.db.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
} }

View File

@ -1324,11 +1324,11 @@ func TestSqlStore_SaveGroups(t *testing.T) {
Peers: []string{"peer3", "peer4"}, Peers: []string{"peer3", "peer4"},
}, },
} }
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
require.NoError(t, err) require.NoError(t, err)
groups[1].Peers = []string{} groups[1].Peers = []string{}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
require.NoError(t, err) require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID) group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID)
@ -3240,7 +3240,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
}) })
} }
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groupsToSave) err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
require.NoError(t, err) require.NoError(t, err)
accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)

View File

@ -98,7 +98,7 @@ type Store interface {
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error

View File

@ -676,7 +676,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
} }
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, update.AccountID, updatedGroups); err != nil {
return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err)
} }
} }