From 1a623943c88e872e746d3248a6a2ce963b997bcc Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 30 Dec 2024 12:40:24 +0100 Subject: [PATCH] [management] Fix networks net map generation with posture checks (#3124) --- go.mod | 1 + go.sum | 2 + management/server/account_test.go | 6 +- management/server/peer_test.go | 6 +- management/server/policy_test.go | 32 +- management/server/route.go | 2 +- management/server/types/account.go | 140 +++++---- management/server/types/account_test.go | 382 ++++++++++++++++++++++++ management/server/types/policy.go | 12 +- 9 files changed, 508 insertions(+), 75 deletions(-) diff --git a/go.mod b/go.mod index d48280df0..330d0763f 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 + github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index 540cbf20b..ea4597836 100644 --- a/go.sum +++ b/go.sum @@ -698,6 +698,8 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 h1:S5h7yNKStqF8CqFtgtMNMzk/lUI3p82LrX6h2BhlsTM= +github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907/go.mod h1:/7Fy/4/OyrkguTf2i2pO4erUD/8QAlrlmXSdSJPu678= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/account_test.go b/management/server/account_test.go index d83eab6d1..280d998fd 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3037,9 +3037,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3, 3, 10}, + {"Small", 50, 5, 1, 3, 3, 11}, {"Medium", 500, 100, 7, 13, 10, 70}, - {"Large", 5000, 200, 65, 80, 60, 200}, + {"Large", 5000, 200, 65, 80, 60, 220}, {"Small single", 50, 10, 1, 3, 3, 70}, {"Medium single", 500, 10, 7, 13, 10, 26}, {"Large 5", 5000, 15, 65, 80, 60, 200}, @@ -3179,7 +3179,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 107, 120, 107, 160}, - {"Medium", 500, 100, 105, 140, 105, 190}, + {"Medium", 500, 100, 105, 140, 105, 220}, {"Large", 5000, 200, 180, 220, 180, 350}, {"Small single", 50, 10, 107, 120, 105, 160}, {"Medium single", 500, 10, 105, 140, 105, 170}, diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2ab262ff0..9ad67d2bf 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { }{ {"Small", 50, 5, 90, 120, 90, 120}, {"Medium", 500, 100, 110, 150, 120, 260}, - {"Large", 5000, 200, 800, 1390, 2500, 4600}, + {"Large", 5000, 200, 800, 1700, 2500, 5000}, {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, - {"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, - {"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, + {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, + {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400}, } log.SetOutput(io.Discard) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index fab738abe..0d17da23a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) { "peerH", }, }, + "GroupWorkstations": { + ID: "GroupWorkstations", + Name: "All", + Peers: []string{ + "peerB", + "peerA", + "peerD", + "peerE", + "peerF", + "peerG", + "peerH", + }, + }, "GroupSwarm": { ID: "GroupSwarm", Name: "swarm", @@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", - "GroupAll", + "GroupWorkstations", }, Destinations: []string{ "GroupSwarm", @@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) + assert.Contains(t, peers, account.Peers["peerG"]) + assert.Contains(t, peers, account.Peers["peerH"]) epectedFirewallRules := []*types.FirewallRule{ { @@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Protocol: "all", Port: "", }, - { - PeerIP: "100.65.254.139", - Direction: types.FirewallRuleDirectionOUT, - Action: "accept", - Protocol: "all", - Port: "", - }, - { - PeerIP: "100.65.254.139", - Direction: types.FirewallRuleDirectionIN, - Action: "accept", - Protocol: "all", - Port: "", - }, - { PeerIP: "100.65.62.5", Direction: types.FirewallRuleDirectionOUT, diff --git a/management/server/route.go b/management/server/route.go index 1eb51aea7..b6b44fbbd 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route { } func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0) + protoRoutes := make([]*proto.Route, 0, len(routes)) for _, r := range routes { protoRoutes = append(protoRoutes, toProtocolRoute(r)) } diff --git a/management/server/types/account.go b/management/server/types/account.go index b36b719e4..3ef862fa6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/yourbasic/radix" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -1045,37 +1046,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { + uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + for _, p := range uniquePeerIDs { + peer, ok := a.Peers[p] + if !ok || peer == nil { continue } - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) } + return filteredPeers, peerInGroups } @@ -1151,7 +1147,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli continue } - rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) fwRules = append(fwRules, rules...) } @@ -1159,8 +1155,8 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli return fwRules } -func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) +func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make([]string, 0) for _, id := range rule.Sources { group := a.Groups[id] if group == nil { @@ -1173,14 +1169,17 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer } _, distPeer := distributionPeers[pID] _, valid := validatedPeersMap[pID] - if distPeer && valid { - distPeersWithPolicy[pID] = struct{}{} + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy = append(distPeersWithPolicy, pID) } } } - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { + radix.Sort(distPeersWithPolicy) + uniqueDistributionPeers := slices.Compact(distPeersWithPolicy) + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(uniqueDistributionPeers)) + for _, pID := range uniqueDistributionPeers { peer := a.Peers[pID] if peer == nil { continue @@ -1271,7 +1270,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } } return routesFirewallRules @@ -1306,7 +1309,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { var isRoutingPeer bool var routes []*route.Route - var allSourcePeers []string + allSourcePeers := make([]string, 0) for _, resource := range a.NetworkResources { var addSourcePeers bool @@ -1319,28 +1322,63 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st } } + addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - for _, sourceGroup := range policy.SourceGroups() { - group := a.GetGroup(sourceGroup) - if group == nil { - log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) - continue - } - - // routing peer should be able to connect with all source peers - if addSourcePeers { - allSourcePeers = append(allSourcePeers, group.Peers...) - } else if slices.Contains(group.Peers, peerID) { - // add routes for the resource if the peer is in the distribution group - for peerId, router := range networkRoutingPeers { - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) - } + peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + if addSourcePeers { + allSourcePeers = append(allSourcePeers, a.getPostureValidPeers(peers, policy.SourcePostureChecks)...) + } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + // add routes for the resource if the peer is in the distribution group + for peerId, router := range networkRoutingPeers { + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) } + addedResourceRoute = true + } + if addedResourceRoute { + break } } } - return isRoutingPeer, routes, allSourcePeers + radix.Sort(allSourcePeers) + return isRoutingPeer, routes, slices.Compact(allSourcePeers) +} + +func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { + dest = append(dest, peerID) + } + } + return dest +} + +func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { + gObjs := make([]*Group, 0, len(groups)) + tp := 0 + for _, groupID := range groups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id) + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + gObjs = append(gObjs, group) + tp += len(group.Peers) + } + + ids := make([]string, 0, tp) + for _, group := range gObjs { + ids = append(ids, group.Peers...) + } + + radix.Sort(ids) + return slices.Compact(ids) } // getNetworkResources filters and returns a list of network resources associated with the given network ID. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index c73421d16..efe930108 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -1,14 +1,20 @@ package types import ( + "context" + "net" + "net/netip" + "slices" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" ) @@ -373,3 +379,379 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) require.Len(t, result, 0) } + +const ( + accID = "accountID" + network1ID = "network1ID" + group1ID = "group1" + accNetResourcePeer1ID = "peer1" + accNetResourcePeer2ID = "peer2" + accNetResourceRouter1ID = "router1" + accNetResource1ID = "resource1ID" + accNetResourceRestrictPostureCheckID = "restrictPostureCheck" + accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck" + accNetResourceLockedPostureCheckID = "lockedPostureCheck" + accNetResourceLinuxPostureCheckID = "linuxPostureCheck" +) + +var ( + accNetResourcePeer1IP = net.IP{192, 168, 1, 1} + accNetResourcePeer2IP = net.IP{192, 168, 1, 2} + accNetResourceRouter1IP = net.IP{192, 168, 1, 3} + accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} +) + +func getBasicAccountsWithResource() *Account { + return &Account{ + Id: accID, + Peers: map[string]*nbpeer.Peer{ + accNetResourcePeer1ID: { + ID: accNetResourcePeer1ID, + AccountID: accID, + Key: "peer1Key", + IP: accNetResourcePeer1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourcePeer2ID: { + ID: accNetResourcePeer2ID, + AccountID: accID, + Key: "peer2Key", + IP: accNetResourcePeer2IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "windows", + WtVersion: "0.34.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourceRouter1ID: { + ID: accNetResourceRouter1ID, + AccountID: accID, + Key: "router1Key", + IP: accNetResourceRouter1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + }, + Groups: map[string]*Group{ + group1ID: { + ID: group1ID, + Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: network1ID, + AccountID: accID, + Name: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: accNetResourceRouter1ID, + NetworkID: network1ID, + AccountID: accID, + Peer: accNetResourceRouter1ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: accNetResource1ID, + AccountID: accID, + NetworkID: network1ID, + Address: "10.10.10.0/24", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + Type: resourceTypes.NetworkResourceType("subnet"), + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: nil, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: accNetResourceRestrictPostureCheckID, + Name: accNetResourceRestrictPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.35.0", + }, + }, + }, + { + ID: accNetResourceRelaxedPostureCheckID, + Name: accNetResourceRelaxedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + { + ID: accNetResourceLockedPostureCheckID, + Name: accNetResourceLockedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "7.7.7", + }, + }, + }, + { + ID: accNetResourceLinuxPostureCheckID, + Name: accNetResourceLinuxPostureCheckID, + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{ + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "0.0.0"}, + }, + }, + }, + }, + } +} + +func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // all peers should match the policy + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should not match any peer + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 0, "expected rules count don't match") +} + +func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // should allow peer1 and peer2 to match the policy + newPolicy := &Policy{ + ID: "policy2ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "policy2ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, + } + + account.Policies = append(account.Policies, newPolicy) + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 2, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } + + assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") + assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // two posture checks should match only the peers that match both checks + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index c2b82d68a..5b2cf06a0 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,11 @@ package types +import ( + "slices" + + "github.com/yourbasic/radix" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -117,9 +123,13 @@ func (p *Policy) RuleGroups() []string { // SourceGroups returns a slice of all unique source groups referenced in the policy's rules. func (p *Policy) SourceGroups() []string { + if len(p.Rules) == 1 { + return p.Rules[0].Sources + } groups := make([]string, 0) for _, rule := range p.Rules { groups = append(groups, rule.Sources...) } - return groups + radix.Sort(groups) + return slices.Compact(groups) }