Retrieve policy groups and posture checks once for validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-12 18:53:10 +03:00
parent 2a59f04540
commit 32d1b2d602
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
3 changed files with 26 additions and 18 deletions

View File

@ -521,12 +521,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
policy.AccountID = accountID policy.AccountID = accountID
} }
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
if err != nil { if err != nil {
return err return err
} }
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return err
} }
@ -629,15 +629,10 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
} }
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
validPostureCheckIDs := make(map[string]struct{})
for _, check := range postureChecks {
validPostureCheckIDs[check.ID] = struct{}{}
}
validIDs := make([]string, 0, len(postureChecksIds)) validIDs := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds { for _, id := range postureChecksIds {
if _, exists := validPostureCheckIDs[id]; exists { if _, exists := postureChecks[id]; exists {
validIDs = append(validIDs, id) validIDs = append(validIDs, id)
} }
} }
@ -646,15 +641,10 @@ func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds [
} }
// getValidGroupIDs filters and returns only the valid group IDs from the provided list. // getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
validGroupIDs := make(map[string]struct{})
for _, group := range groups {
validGroupIDs[group.ID] = struct{}{}
}
validIDs := make([]string, 0, len(groupIDs)) validIDs := make([]string, 0, len(groupIDs))
for _, id := range groupIDs { for _, id := range groupIDs {
if _, exists := validGroupIDs[id]; exists { if _, exists := groups[id]; exists {
validIDs = append(validIDs, id) validIDs = append(validIDs, id)
} }
} }

View File

@ -1234,8 +1234,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
} }
groupsMap := make(map[string]*nbgroup.Group) groupsMap := make(map[string]*nbgroup.Group)
@ -1377,6 +1377,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
return postureCheck, nil return postureCheck, nil
} }
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
var postureChecks []*posture.Checks
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
}
postureChecksMap := make(map[string]*posture.Checks)
for _, postureCheck := range postureChecks {
postureChecksMap[postureCheck.ID] = postureCheck
}
return postureChecksMap, nil
}
// SavePostureChecks saves a posture checks to the database. // SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)

View File

@ -88,6 +88,7 @@ type Store interface {
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error