mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-11 16:38:27 +01:00
Retrieve policy groups and posture checks once for validation
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
2a59f04540
commit
32d1b2d602
@ -521,12 +521,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
|
||||
policy.AccountID = accountID
|
||||
}
|
||||
|
||||
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -629,15 +629,10 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
}
|
||||
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
|
||||
validPostureCheckIDs := make(map[string]struct{})
|
||||
for _, check := range postureChecks {
|
||||
validPostureCheckIDs[check.ID] = struct{}{}
|
||||
}
|
||||
|
||||
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
||||
validIDs := make([]string, 0, len(postureChecksIds))
|
||||
for _, id := range postureChecksIds {
|
||||
if _, exists := validPostureCheckIDs[id]; exists {
|
||||
if _, exists := postureChecks[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
@ -646,15 +641,10 @@ func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds [
|
||||
}
|
||||
|
||||
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
|
||||
validGroupIDs := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
validGroupIDs[group.ID] = struct{}{}
|
||||
}
|
||||
|
||||
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
|
||||
validIDs := make([]string, 0, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if _, exists := validGroupIDs[id]; exists {
|
||||
if _, exists := groups[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
|
@ -1234,8 +1234,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group)
|
||||
@ -1377,6 +1377,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
|
||||
return postureCheck, nil
|
||||
}
|
||||
|
||||
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
|
||||
}
|
||||
|
||||
postureChecksMap := make(map[string]*posture.Checks)
|
||||
for _, postureCheck := range postureChecks {
|
||||
postureChecksMap[postureCheck.ID] = postureCheck
|
||||
}
|
||||
|
||||
return postureChecksMap, nil
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture checks to the database.
|
||||
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||
|
@ -88,6 +88,7 @@ type Store interface {
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user