diff --git a/management/server/account.go b/management/server/account.go index e99e0e7f3..f89e84ddc 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -413,6 +413,7 @@ func (a *Account) GetPeerNetworkMap( peersCustomZone nbdns.CustomZone, validatedPeersMap map[string]struct{}, metrics *telemetry.AccountManagerMetrics, + expandedPolicies policyRuleExpandedPeers, ) *NetworkMap { start := time.Now() @@ -429,7 +430,7 @@ func (a *Account) GetPeerNetworkMap( } } - aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) + aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap, expandedPolicies) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer diff --git a/management/server/account_test.go b/management/server/account_test.go index 03b5fa83e..5972235dd 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -412,7 +412,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) + policyExpandedPeers := account.getPolicyExpandedPeers() + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil, policyExpandedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } diff --git a/management/server/peer.go b/management/server/peer.go index 7afe6ee0d..0fcbe2a21 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -87,8 +87,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } // fetch all the peers that have access to the user's peers + policyExpandedPeers := account.getPolicyExpandedPeers() for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap, policyExpandedPeers) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -324,7 +325,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil + policyExpandedPeers := account.getPolicyExpandedPeers() + return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil, policyExpandedPeers), nil } // GetPeerNetwork returns the Network for a given peer @@ -538,7 +540,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s postureChecks := am.getPeerPostureChecks(account, peer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) + policyExpandedPeers := account.getPolicyExpandedPeers() + networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers) return newPeer, networkMap, postureChecks, nil } @@ -595,7 +598,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac postureChecks = am.getPeerPostureChecks(account, peer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + policyExpandedPeers := account.getPolicyExpandedPeers() + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil } // LoginPeer logs in or registers a peer. @@ -743,7 +747,8 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is postureChecks = am.getPeerPostureChecks(account, peer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + policyExpandedPeers := account.getPolicyExpandedPeers() + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil } func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { @@ -896,8 +901,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } + policyExpandedPeers := account.getPolicyExpandedPeers() for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap, policyExpandedPeers) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -939,7 +945,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account dnsCache := &DNSConfigCache{} customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - + expandedPolicies := account.getPolicyExpandedPeers() for _, peer := range peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -953,7 +959,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer func() { <-semaphore }() postureChecks := am.getPeerPostureChecks(account, p) - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), expandedPolicies) update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) }(peer) diff --git a/management/server/policy.go b/management/server/policy.go index aaf9b6e72..1a0223703 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -212,20 +212,20 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}, expandedPolicies policyRuleExpandedPeers) ([]*nbpeer.Peer, []*FirewallRule) { generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) for _, policy := range a.Policies { if !policy.Enabled { continue } - for _, rule := range policy.Rules { + for n, rule := range policy.Rules { if !rule.Enabled { continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].sourcePeers, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].destinationPeers, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -490,38 +490,26 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { +func (a *Account) getAllPeersFromGroups(ctx context.Context, peerMap peerMap, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { + filteredPeers := make([]*nbpeer.Peer, 0, len(peerMap)) + for _, peer := range peerMap { + + if _, ok := validatedPeersMap[peer.ID]; !ok { continue } - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) } return filteredPeers, peerInGroups } @@ -560,3 +548,41 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } return nil } + +type expandedRuleGroups struct { + sourcePeers peerMap + destinationPeers peerMap +} + +type peerMap map[string]*nbpeer.Peer + +type policyRuleExpandedPeers map[string]map[int]expandedRuleGroups + +func (a *Account) getPolicyExpandedPeers() policyRuleExpandedPeers { + policyMap := make(policyRuleExpandedPeers) + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + ruleMap := make(map[int]expandedRuleGroups) + policyMap[policy.ID] = ruleMap + for ruleID, rule := range policy.Rules { + policyMap[policy.ID][ruleID] = expandedRuleGroups{ + sourcePeers: make(peerMap), + destinationPeers: make(peerMap), + } + a.processGroups(rule.Sources, policyMap[policy.ID][ruleID].sourcePeers) + a.processGroups(rule.Destinations, policyMap[policy.ID][ruleID].destinationPeers) + } + } + return policyMap +} + +func (a *Account) processGroups(groupIDs []string, peerMap peerMap) { + for _, pid := range groupIDs { + p, ok := a.Peers[pid] + if ok { + peerMap[pid] = p + } + } +} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index bf9a53d16..cb3f426c8 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -143,15 +143,17 @@ func TestAccount_getPeersByPolicy(t *testing.T) { } t.Run("check that all peers get map", func(t *testing.T) { + policyExpandedPeers := account.getPolicyExpandedPeers() for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers, policyExpandedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) + policyExpandedPeers := account.getPolicyExpandedPeers() + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers, policyExpandedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -387,7 +389,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + policyExpandedPeers := account.getPolicyExpandedPeers() + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -415,7 +418,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + policyExpandedPeers := account.getPolicyExpandedPeers() + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -445,7 +449,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + policyExpandedPeers := account.getPolicyExpandedPeers() + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -466,7 +471,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + policyExpandedPeers := account.getPolicyExpandedPeers() + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -661,9 +667,10 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { approvedPeers[p] = struct{}{} } t.Run("verify peer's network map with default group peer list", func(t *testing.T) { + policyExpandedPeers := account.getPolicyExpandedPeers() // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -673,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -689,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -699,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -711,22 +718,23 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with modified group peer list", func(t *testing.T) { // Removing peerB as the part of destination group Swarm account.Groups["GroupSwarm"].Peers = []string{"peerA", "peerD", "peerE", "peerG", "peerH"} + policyExpandedPeers := account.getPolicyExpandedPeers() // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -738,17 +746,18 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // Removing peerF as the part of source group All account.Groups["GroupAll"].Peers = []string{"peerB", "peerA", "peerD", "peerC", "peerG", "peerH"} + policyExpandedPeers = account.getPolicyExpandedPeers() // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers, policyExpandedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"])