mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-28 21:51:40 +02:00
Retrieve policy groups and posture checks once for validation
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
2a59f04540
commit
32d1b2d602
@ -521,12 +521,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
|
|||||||
policy.AccountID = accountID
|
policy.AccountID = accountID
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -629,15 +629,10 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||||
func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
|
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
||||||
validPostureCheckIDs := make(map[string]struct{})
|
|
||||||
for _, check := range postureChecks {
|
|
||||||
validPostureCheckIDs[check.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
validIDs := make([]string, 0, len(postureChecksIds))
|
validIDs := make([]string, 0, len(postureChecksIds))
|
||||||
for _, id := range postureChecksIds {
|
for _, id := range postureChecksIds {
|
||||||
if _, exists := validPostureCheckIDs[id]; exists {
|
if _, exists := postureChecks[id]; exists {
|
||||||
validIDs = append(validIDs, id)
|
validIDs = append(validIDs, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -646,15 +641,10 @@ func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds [
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||||
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
|
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
|
||||||
validGroupIDs := make(map[string]struct{})
|
|
||||||
for _, group := range groups {
|
|
||||||
validGroupIDs[group.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
validIDs := make([]string, 0, len(groupIDs))
|
validIDs := make([]string, 0, len(groupIDs))
|
||||||
for _, id := range groupIDs {
|
for _, id := range groupIDs {
|
||||||
if _, exists := validGroupIDs[id]; exists {
|
if _, exists := groups[id]; exists {
|
||||||
validIDs = append(validIDs, id)
|
validIDs = append(validIDs, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1234,8 +1234,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
|||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
groupsMap := make(map[string]*nbgroup.Group)
|
groupsMap := make(map[string]*nbgroup.Group)
|
||||||
@ -1377,6 +1377,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
|
|||||||
return postureCheck, nil
|
return postureCheck, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
|
||||||
|
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
|
||||||
|
var postureChecks []*posture.Checks
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecksMap := make(map[string]*posture.Checks)
|
||||||
|
for _, postureCheck := range postureChecks {
|
||||||
|
postureChecksMap[postureCheck.ID] = postureCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureChecksMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SavePostureChecks saves a posture checks to the database.
|
// SavePostureChecks saves a posture checks to the database.
|
||||||
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||||
|
@ -88,6 +88,7 @@ type Store interface {
|
|||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||||
|
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user