mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-03 00:25:37 +02:00
[management] refactor to use account object instead of separate db calls for peer update (#2957)
This commit is contained in:
parent
9203690033
commit
00c3b67182
@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
|
postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@ -707,7 +707,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@ -885,7 +885,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@ -1042,7 +1042,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
|
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
|
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
|
||||||
return
|
return
|
||||||
|
@ -833,19 +833,20 @@ func BenchmarkGetPeers(b *testing.B) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||||
benchCases := []struct {
|
benchCases := []struct {
|
||||||
name string
|
name string
|
||||||
peers int
|
peers int
|
||||||
groups int
|
groups int
|
||||||
|
minMsPerOp float64
|
||||||
|
maxMsPerOp float64
|
||||||
}{
|
}{
|
||||||
{"Small", 50, 5},
|
{"Small", 50, 5, 90, 120},
|
||||||
{"Medium", 500, 10},
|
{"Medium", 500, 100, 110, 140},
|
||||||
{"Large", 5000, 20},
|
{"Large", 5000, 200, 800, 1300},
|
||||||
{"Small single", 50, 1},
|
{"Small single", 50, 10, 90, 120},
|
||||||
{"Medium single", 500, 1},
|
{"Medium single", 500, 10, 110, 170},
|
||||||
{"Large 5", 5000, 5},
|
{"Large 5", 5000, 15, 1300, 1800},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
@ -881,8 +882,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
b.ReportMetric(0, "ns/op")
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
if msPerOp < bc.minMsPerOp {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > bc.maxMsPerOp {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,14 +2,16 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/rs/xid"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
@ -149,38 +151,21 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||||
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
|
||||||
peerPostureChecks := make(map[string]*posture.Checks)
|
peerPostureChecks := make(map[string]*posture.Checks)
|
||||||
|
|
||||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
if len(account.PostureChecks) == 0 {
|
||||||
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
return nil, nil
|
||||||
if err != nil {
|
}
|
||||||
return err
|
|
||||||
|
for _, policy := range account.Policies {
|
||||||
|
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(postureChecks) == 0 {
|
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, policy := range policies {
|
|
||||||
if !policy.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return maps.Values(peerPostureChecks), nil
|
return maps.Values(peerPostureChecks), nil
|
||||||
@ -241,8 +226,8 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||||
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||||
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
|
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -252,9 +237,9 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||||
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
|
postureCheck := account.getPostureChecks(sourcePostureCheckID)
|
||||||
if err != nil {
|
if postureCheck == nil {
|
||||||
return err
|
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||||
}
|
}
|
||||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||||
}
|
}
|
||||||
@ -263,16 +248,16 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||||
func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
|
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if !rule.Enabled {
|
if !rule.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sourceGroup := range rule.Sources {
|
for _, sourceGroup := range rule.Sources {
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
|
group := account.GetGroup(sourceGroup)
|
||||||
if err != nil {
|
if group == nil {
|
||||||
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
|
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(group.Peers, peerID) {
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user