mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-14 01:08:46 +02:00
[management] refactor to use account object instead of separate db calls for peer update (#2957)
This commit is contained in:
@ -2,14 +2,16 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
@ -149,38 +151,21 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
if len(account.PostureChecks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(postureChecks) == 0 {
|
||||
return nil
|
||||
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
@ -241,8 +226,8 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str
|
||||
}
|
||||
|
||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
|
||||
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -252,9 +237,9 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p
|
||||
}
|
||||
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
|
||||
if err != nil {
|
||||
return err
|
||||
postureCheck := account.getPostureChecks(sourcePostureCheckID)
|
||||
if postureCheck == nil {
|
||||
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||
}
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
@ -263,16 +248,16 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
|
||||
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
|
||||
group := account.GetGroup(sourceGroup)
|
||||
if group == nil {
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
|
Reference in New Issue
Block a user