Use policy expanded peers map from src/dest groups

Pre expand the peers from policy rules source and destination groups
to avoid extra allocation when calculating network map
This commit is contained in:
Maycon Santos 2024-08-07 15:21:02 +02:00
parent bcce1bf184
commit ec4469f43d
5 changed files with 99 additions and 56 deletions

View File

@ -413,6 +413,7 @@ func (a *Account) GetPeerNetworkMap(
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
expandedPolicies policyRuleExpandedPeers,
) *NetworkMap {
start := time.Now()
@ -429,7 +430,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap)
aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap, expandedPolicies)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer

View File

@ -412,7 +412,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
policyExpandedPeers := account.getPolicyExpandedPeers()
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil, policyExpandedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}

View File

@ -87,8 +87,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
}
// fetch all the peers that have access to the user's peers
policyExpandedPeers := account.getPolicyExpandedPeers()
for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap, policyExpandedPeers)
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@ -324,7 +325,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
policyExpandedPeers := account.getPolicyExpandedPeers()
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil, policyExpandedPeers), nil
}
// GetPeerNetwork returns the Network for a given peer
@ -538,7 +540,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
postureChecks := am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
policyExpandedPeers := account.getPolicyExpandedPeers()
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers)
return newPeer, networkMap, postureChecks, nil
}
@ -595,7 +598,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
policyExpandedPeers := account.getPolicyExpandedPeers()
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil
}
// LoginPeer logs in or registers a peer.
@ -743,7 +747,8 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
policyExpandedPeers := account.getPolicyExpandedPeers()
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
@ -896,8 +901,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, err
}
policyExpandedPeers := account.getPolicyExpandedPeers()
for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap, policyExpandedPeers)
for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID {
return peer, nil
@ -939,7 +945,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
expandedPolicies := account.getPolicyExpandedPeers()
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@ -953,7 +959,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), expandedPolicies)
update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)

View File

@ -212,20 +212,20 @@ type FirewallRule struct {
// getPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}, expandedPolicies policyRuleExpandedPeers) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
for n, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].sourcePeers, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].destinationPeers, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
@ -490,38 +490,26 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
func (a *Account) getAllPeersFromGroups(ctx context.Context, peerMap peerMap, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
group, ok := a.Groups[g]
if !ok {
filteredPeers := make([]*nbpeer.Peer, 0, len(peerMap))
for _, peer := range peerMap {
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
for _, p := range group.Peers {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
}
return filteredPeers, peerInGroups
}
@ -560,3 +548,41 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
}
return nil
}
type expandedRuleGroups struct {
sourcePeers peerMap
destinationPeers peerMap
}
type peerMap map[string]*nbpeer.Peer
type policyRuleExpandedPeers map[string]map[int]expandedRuleGroups
func (a *Account) getPolicyExpandedPeers() policyRuleExpandedPeers {
policyMap := make(policyRuleExpandedPeers)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
ruleMap := make(map[int]expandedRuleGroups)
policyMap[policy.ID] = ruleMap
for ruleID, rule := range policy.Rules {
policyMap[policy.ID][ruleID] = expandedRuleGroups{
sourcePeers: make(peerMap),
destinationPeers: make(peerMap),
}
a.processGroups(rule.Sources, policyMap[policy.ID][ruleID].sourcePeers)
a.processGroups(rule.Destinations, policyMap[policy.ID][ruleID].destinationPeers)
}
}
return policyMap
}
func (a *Account) processGroups(groupIDs []string, peerMap peerMap) {
for _, pid := range groupIDs {
p, ok := a.Peers[pid]
if ok {
peerMap[pid] = p
}
}
}

View File

@ -143,15 +143,17 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}
t.Run("check that all peers get map", func(t *testing.T) {
policyExpandedPeers := account.getPolicyExpandedPeers()
for _, p := range account.Peers {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers, policyExpandedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers)
policyExpandedPeers := account.getPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers, policyExpandedPeers)
assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@ -387,7 +389,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
policyExpandedPeers := account.getPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{
@ -415,7 +418,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
policyExpandedPeers := account.getPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{
@ -445,7 +449,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
policyExpandedPeers := account.getPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{
@ -466,7 +471,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
policyExpandedPeers := account.getPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{
@ -661,9 +667,10 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
approvedPeers[p] = struct{}{}
}
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
policyExpandedPeers := account.getPolicyExpandedPeers()
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -673,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*FirewallRule{
@ -689,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -699,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -711,22 +718,23 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with modified group peer list", func(t *testing.T) {
// Removing peerB as the part of destination group Swarm
account.Groups["GroupSwarm"].Peers = []string{"peerA", "peerD", "peerE", "peerG", "peerH"}
policyExpandedPeers := account.getPolicyExpandedPeers()
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@ -738,17 +746,18 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// Removing peerF as the part of source group All
account.Groups["GroupAll"].Peers = []string{"peerB", "peerA", "peerD", "peerC", "peerG", "peerH"}
policyExpandedPeers = account.getPolicyExpandedPeers()
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])