From 445b626dc8cf1f8489291bc5da59bcfdf89391ad Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 27 Dec 2024 14:39:34 +0300 Subject: [PATCH 01/10] [management] Add missing group usage checks for network resources and routes access control (#3117) * Prevent deletion of groups linked to routes access control groups Signed-off-by: bcmmbaga * Prevent deletion of groups linked to network resource Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/group.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/group.go b/management/server/group.go index d433a3485..f1057dda6 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") } + if len(group.Resources) > 0 { + return &GroupLinkError{"network resource", group.Resources[0].ID} + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } @@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI } for _, r := range routes { - if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + isLinked := slices.Contains(r.Groups, groupID) || + slices.Contains(r.PeerGroups, groupID) || + slices.Contains(r.AccessControlGroups, groupID) + if isLinked { return true, r } } From fbce8bb51197c2e29a168386bf3ba2da7fa16c1a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:13:36 +0100 Subject: [PATCH 02/10] [management] remove ids from policy creation api (#2997) --- management/server/http/api/openapi.yml | 62 ++++++++++++++++--- management/server/http/api/types.gen.go | 37 ++++++----- .../handlers/policies/policies_handler.go | 11 +++- .../policies/policies_handler_test.go | 11 ++-- 4 files changed, 90 insertions(+), 31 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 351976baf..6c1d6b424 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -725,10 +725,6 @@ components: PolicyRuleMinimum: type: object properties: - id: - description: Policy rule ID - type: string - example: ch8i4ug6lnn4g9hqv7mg name: description: Policy rule name identifier type: string @@ -790,6 +786,31 @@ components: - end PolicyRuleUpdate: + allOf: + - $ref: '#/components/schemas/PolicyRuleMinimum' + - type: object + properties: + id: + description: Policy rule ID + type: string + example: ch8i4ug6lnn4g9hqv7mg + sources: + description: Policy rule source group IDs + type: array + items: + type: string + example: "ch8i4ug6lnn4g9hqv797" + destinations: + description: Policy rule destination group IDs + type: array + items: + type: string + example: "ch8i4ug6lnn4g9h7v7m0" + required: + - sources + - destinations + + PolicyRuleCreate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' - type: object @@ -817,6 +838,10 @@ components: - $ref: '#/components/schemas/PolicyRuleMinimum' - type: object properties: + id: + description: Policy rule ID + type: string + example: ch8i4ug6lnn4g9hqv7mg sources: description: Policy rule source group IDs type: array @@ -836,10 +861,6 @@ components: PolicyMinimum: type: object properties: - id: - description: Policy ID - type: string - example: ch8i4ug6lnn4g9hqv7mg name: description: Policy name identifier type: string @@ -854,7 +875,6 @@ components: example: true required: - name - - description - enabled PolicyUpdate: allOf: @@ -874,11 +894,33 @@ components: $ref: '#/components/schemas/PolicyRuleUpdate' required: - rules + PolicyCreate: + allOf: + - $ref: '#/components/schemas/PolicyMinimum' + - type: object + properties: + source_posture_checks: + description: Posture checks ID's applied to policy source groups + type: array + items: + type: string + example: "chacdk86lnnboviihd70" + rules: + description: Policy rule object for policy UI editor + type: array + items: + $ref: '#/components/schemas/PolicyRuleUpdate' + required: + - rules Policy: allOf: - $ref: '#/components/schemas/PolicyMinimum' - type: object properties: + id: + description: Policy ID + type: string + example: ch8i4ug6lnn4g9hqv7mg source_posture_checks: description: Posture checks ID's applied to policy source groups type: array @@ -2463,7 +2505,7 @@ paths: content: 'application/json': schema: - $ref: '#/components/schemas/PolicyUpdate' + $ref: '#/components/schemas/PolicyCreate' responses: '200': description: A Policy object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 40574d6f1..83226587f 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct { // Policy defines model for Policy. type Policy struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` @@ -897,16 +897,31 @@ type Policy struct { SourcePostureChecks []string `json:"source_posture_checks"` } -// PolicyMinimum defines model for PolicyMinimum. -type PolicyMinimum struct { +// PolicyCreate defines model for PolicyCreate. +type PolicyCreate struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` - // Id Policy ID - Id *string `json:"id,omitempty"` + // Name Policy name identifier + Name string `json:"name"` + + // Rules Policy rule object for policy UI editor + Rules []PolicyRuleUpdate `json:"rules"` + + // SourcePostureChecks Posture checks ID's applied to policy source groups + SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"` +} + +// PolicyMinimum defines model for PolicyMinimum. +type PolicyMinimum struct { + // Description Policy friendly description + Description *string `json:"description,omitempty"` + + // Enabled Policy status + Enabled bool `json:"enabled"` // Name Policy name identifier Name string `json:"name"` @@ -970,9 +985,6 @@ type PolicyRuleMinimum struct { // Enabled Policy rule status Enabled bool `json:"enabled"` - // Id Policy rule ID - Id *string `json:"id,omitempty"` - // Name Policy rule name identifier Name string `json:"name"` @@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string // PolicyUpdate defines model for PolicyUpdate. type PolicyUpdate struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` - // Id Policy ID - Id *string `json:"id,omitempty"` - // Name Policy name identifier Name string `json:"name"` @@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest type PostApiPoliciesJSONRequestBody = PolicyUpdate // PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType. -type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate +type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate // PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType. type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index d538d07db..b1035c570 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s return } + description := "" + if req.Description != nil { + description = *req.Description + } + policy := &types.Policy{ ID: policyID, AccountID: accountID, Name: req.Name, Enabled: req.Enabled, - Description: req.Description, + Description: description, } for _, rule := range req.Rules { var ruleID string - if rule.Id != nil { + if rule.Id != nil && policyID != "" { ruleID = *rule.Id } @@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { ap := &api.Policy{ Id: &policy.ID, Name: policy.Name, - Description: policy.Description, + Description: &policy.Description, Enabled: policy.Enabled, SourcePostureChecks: policy.SourcePostureChecks, } diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 956d0b7cd..3e1be187c 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) { func TestPoliciesWritePolicy(t *testing.T) { str := func(s string) *string { return &s } + emptyString := "" tt := []struct { name string expectedStatus int @@ -184,8 +185,9 @@ func TestPoliciesWritePolicy(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: true, expectedPolicy: &api.Policy{ - Id: str("id-was-set"), - Name: "Default POSTed Policy", + Id: str("id-was-set"), + Name: "Default POSTed Policy", + Description: &emptyString, Rules: []api.PolicyRule{ { Id: str("id-was-set"), @@ -232,8 +234,9 @@ func TestPoliciesWritePolicy(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: true, expectedPolicy: &api.Policy{ - Id: str("id-existed"), - Name: "Default POSTed Policy", + Id: str("id-existed"), + Name: "Default POSTed Policy", + Description: &emptyString, Rules: []api.PolicyRule{ { Id: str("id-existed"), From 1a623943c88e872e746d3248a6a2ce963b997bcc Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 30 Dec 2024 12:40:24 +0100 Subject: [PATCH 03/10] [management] Fix networks net map generation with posture checks (#3124) --- go.mod | 1 + go.sum | 2 + management/server/account_test.go | 6 +- management/server/peer_test.go | 6 +- management/server/policy_test.go | 32 +- management/server/route.go | 2 +- management/server/types/account.go | 140 +++++---- management/server/types/account_test.go | 382 ++++++++++++++++++++++++ management/server/types/policy.go | 12 +- 9 files changed, 508 insertions(+), 75 deletions(-) diff --git a/go.mod b/go.mod index d48280df0..330d0763f 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 + github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index 540cbf20b..ea4597836 100644 --- a/go.sum +++ b/go.sum @@ -698,6 +698,8 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 h1:S5h7yNKStqF8CqFtgtMNMzk/lUI3p82LrX6h2BhlsTM= +github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907/go.mod h1:/7Fy/4/OyrkguTf2i2pO4erUD/8QAlrlmXSdSJPu678= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/account_test.go b/management/server/account_test.go index d83eab6d1..280d998fd 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3037,9 +3037,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3, 3, 10}, + {"Small", 50, 5, 1, 3, 3, 11}, {"Medium", 500, 100, 7, 13, 10, 70}, - {"Large", 5000, 200, 65, 80, 60, 200}, + {"Large", 5000, 200, 65, 80, 60, 220}, {"Small single", 50, 10, 1, 3, 3, 70}, {"Medium single", 500, 10, 7, 13, 10, 26}, {"Large 5", 5000, 15, 65, 80, 60, 200}, @@ -3179,7 +3179,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 107, 120, 107, 160}, - {"Medium", 500, 100, 105, 140, 105, 190}, + {"Medium", 500, 100, 105, 140, 105, 220}, {"Large", 5000, 200, 180, 220, 180, 350}, {"Small single", 50, 10, 107, 120, 105, 160}, {"Medium single", 500, 10, 105, 140, 105, 170}, diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2ab262ff0..9ad67d2bf 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { }{ {"Small", 50, 5, 90, 120, 90, 120}, {"Medium", 500, 100, 110, 150, 120, 260}, - {"Large", 5000, 200, 800, 1390, 2500, 4600}, + {"Large", 5000, 200, 800, 1700, 2500, 5000}, {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, - {"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, - {"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, + {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, + {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400}, } log.SetOutput(io.Discard) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index fab738abe..0d17da23a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) { "peerH", }, }, + "GroupWorkstations": { + ID: "GroupWorkstations", + Name: "All", + Peers: []string{ + "peerB", + "peerA", + "peerD", + "peerE", + "peerF", + "peerG", + "peerH", + }, + }, "GroupSwarm": { ID: "GroupSwarm", Name: "swarm", @@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", - "GroupAll", + "GroupWorkstations", }, Destinations: []string{ "GroupSwarm", @@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) + assert.Contains(t, peers, account.Peers["peerG"]) + assert.Contains(t, peers, account.Peers["peerH"]) epectedFirewallRules := []*types.FirewallRule{ { @@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Protocol: "all", Port: "", }, - { - PeerIP: "100.65.254.139", - Direction: types.FirewallRuleDirectionOUT, - Action: "accept", - Protocol: "all", - Port: "", - }, - { - PeerIP: "100.65.254.139", - Direction: types.FirewallRuleDirectionIN, - Action: "accept", - Protocol: "all", - Port: "", - }, - { PeerIP: "100.65.62.5", Direction: types.FirewallRuleDirectionOUT, diff --git a/management/server/route.go b/management/server/route.go index 1eb51aea7..b6b44fbbd 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route { } func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0) + protoRoutes := make([]*proto.Route, 0, len(routes)) for _, r := range routes { protoRoutes = append(protoRoutes, toProtocolRoute(r)) } diff --git a/management/server/types/account.go b/management/server/types/account.go index b36b719e4..3ef862fa6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/yourbasic/radix" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -1045,37 +1046,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { + uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + for _, p := range uniquePeerIDs { + peer, ok := a.Peers[p] + if !ok || peer == nil { continue } - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) } + return filteredPeers, peerInGroups } @@ -1151,7 +1147,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli continue } - rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) fwRules = append(fwRules, rules...) } @@ -1159,8 +1155,8 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli return fwRules } -func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) +func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make([]string, 0) for _, id := range rule.Sources { group := a.Groups[id] if group == nil { @@ -1173,14 +1169,17 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer } _, distPeer := distributionPeers[pID] _, valid := validatedPeersMap[pID] - if distPeer && valid { - distPeersWithPolicy[pID] = struct{}{} + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy = append(distPeersWithPolicy, pID) } } } - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { + radix.Sort(distPeersWithPolicy) + uniqueDistributionPeers := slices.Compact(distPeersWithPolicy) + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(uniqueDistributionPeers)) + for _, pID := range uniqueDistributionPeers { peer := a.Peers[pID] if peer == nil { continue @@ -1271,7 +1270,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } } return routesFirewallRules @@ -1306,7 +1309,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { var isRoutingPeer bool var routes []*route.Route - var allSourcePeers []string + allSourcePeers := make([]string, 0) for _, resource := range a.NetworkResources { var addSourcePeers bool @@ -1319,28 +1322,63 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st } } + addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - for _, sourceGroup := range policy.SourceGroups() { - group := a.GetGroup(sourceGroup) - if group == nil { - log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) - continue - } - - // routing peer should be able to connect with all source peers - if addSourcePeers { - allSourcePeers = append(allSourcePeers, group.Peers...) - } else if slices.Contains(group.Peers, peerID) { - // add routes for the resource if the peer is in the distribution group - for peerId, router := range networkRoutingPeers { - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) - } + peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + if addSourcePeers { + allSourcePeers = append(allSourcePeers, a.getPostureValidPeers(peers, policy.SourcePostureChecks)...) + } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + // add routes for the resource if the peer is in the distribution group + for peerId, router := range networkRoutingPeers { + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) } + addedResourceRoute = true + } + if addedResourceRoute { + break } } } - return isRoutingPeer, routes, allSourcePeers + radix.Sort(allSourcePeers) + return isRoutingPeer, routes, slices.Compact(allSourcePeers) +} + +func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { + dest = append(dest, peerID) + } + } + return dest +} + +func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { + gObjs := make([]*Group, 0, len(groups)) + tp := 0 + for _, groupID := range groups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id) + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + gObjs = append(gObjs, group) + tp += len(group.Peers) + } + + ids := make([]string, 0, tp) + for _, group := range gObjs { + ids = append(ids, group.Peers...) + } + + radix.Sort(ids) + return slices.Compact(ids) } // getNetworkResources filters and returns a list of network resources associated with the given network ID. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index c73421d16..efe930108 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -1,14 +1,20 @@ package types import ( + "context" + "net" + "net/netip" + "slices" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" ) @@ -373,3 +379,379 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) require.Len(t, result, 0) } + +const ( + accID = "accountID" + network1ID = "network1ID" + group1ID = "group1" + accNetResourcePeer1ID = "peer1" + accNetResourcePeer2ID = "peer2" + accNetResourceRouter1ID = "router1" + accNetResource1ID = "resource1ID" + accNetResourceRestrictPostureCheckID = "restrictPostureCheck" + accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck" + accNetResourceLockedPostureCheckID = "lockedPostureCheck" + accNetResourceLinuxPostureCheckID = "linuxPostureCheck" +) + +var ( + accNetResourcePeer1IP = net.IP{192, 168, 1, 1} + accNetResourcePeer2IP = net.IP{192, 168, 1, 2} + accNetResourceRouter1IP = net.IP{192, 168, 1, 3} + accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} +) + +func getBasicAccountsWithResource() *Account { + return &Account{ + Id: accID, + Peers: map[string]*nbpeer.Peer{ + accNetResourcePeer1ID: { + ID: accNetResourcePeer1ID, + AccountID: accID, + Key: "peer1Key", + IP: accNetResourcePeer1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourcePeer2ID: { + ID: accNetResourcePeer2ID, + AccountID: accID, + Key: "peer2Key", + IP: accNetResourcePeer2IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "windows", + WtVersion: "0.34.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourceRouter1ID: { + ID: accNetResourceRouter1ID, + AccountID: accID, + Key: "router1Key", + IP: accNetResourceRouter1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + }, + Groups: map[string]*Group{ + group1ID: { + ID: group1ID, + Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: network1ID, + AccountID: accID, + Name: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: accNetResourceRouter1ID, + NetworkID: network1ID, + AccountID: accID, + Peer: accNetResourceRouter1ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: accNetResource1ID, + AccountID: accID, + NetworkID: network1ID, + Address: "10.10.10.0/24", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + Type: resourceTypes.NetworkResourceType("subnet"), + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: nil, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: accNetResourceRestrictPostureCheckID, + Name: accNetResourceRestrictPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.35.0", + }, + }, + }, + { + ID: accNetResourceRelaxedPostureCheckID, + Name: accNetResourceRelaxedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + { + ID: accNetResourceLockedPostureCheckID, + Name: accNetResourceLockedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "7.7.7", + }, + }, + }, + { + ID: accNetResourceLinuxPostureCheckID, + Name: accNetResourceLinuxPostureCheckID, + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{ + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "0.0.0"}, + }, + }, + }, + }, + } +} + +func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // all peers should match the policy + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should not match any peer + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 0, "expected rules count don't match") +} + +func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // should allow peer1 and peer2 to match the policy + newPolicy := &Policy{ + ID: "policy2ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "policy2ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, + } + + account.Policies = append(account.Policies, newPolicy) + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 2, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } + + assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") + assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // two posture checks should match only the peers that match both checks + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index c2b82d68a..5b2cf06a0 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,11 @@ package types +import ( + "slices" + + "github.com/yourbasic/radix" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -117,9 +123,13 @@ func (p *Policy) RuleGroups() []string { // SourceGroups returns a slice of all unique source groups referenced in the policy's rules. func (p *Policy) SourceGroups() []string { + if len(p.Rules) == 1 { + return p.Rules[0].Sources + } groups := make([]string, 0) for _, rule := range p.Rules { groups = append(groups, rule.Sources...) } - return groups + radix.Sort(groups) + return slices.Compact(groups) } From 18316be09a1b8d1dfbaa942155e37c46f7375a24 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:53:51 +0100 Subject: [PATCH 04/10] [management] add selfhosted metrics for networks (#3118) --- .github/workflows/golangci-lint.yml | 2 +- management/server/metrics/selfhosted.go | 18 ++++++++++ management/server/metrics/selfhosted_test.go | 37 ++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 89defce32..6705a34ec 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd + ignore_words_list: erro,clienta,hastable,iif,groupd,testin skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 82b34393f..03cb21af1 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties { groups int routes int routesWithRGGroups int + networks int + networkResources int + networkRouters int + networkRoutersWithPG int nameservers int uiClient int version string @@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties { } groups += len(account.Groups) + networks += len(account.Networks) + networkResources += len(account.NetworkResources) + + networkRouters += len(account.NetworkRouters) + for _, router := range account.NetworkRouters { + if len(router.PeerGroups) > 0 { + networkRoutersWithPG++ + } + } + routes += len(account.Routes) for _, route := range account.Routes { if len(route.PeerGroups) > 0 { @@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["posture_checks"] = postureChecks metricsProperties["groups"] = groups + metricsProperties["networks"] = networks + metricsProperties["network_resources"] = networkResources + metricsProperties["network_routers"] = networkRouters + metricsProperties["network_routers_with_groups"] = networkRoutersWithPG metricsProperties["routes"] = routes metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["nameservers"] = nameservers diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 1d356387f..4894c1ac4 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,9 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, }, }, + Networks: []*networkTypes.Network{ + { + ID: "1", + AccountID: "1", + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "1", + AccountID: "1", + NetworkID: "1", + }, + { + ID: "2", + AccountID: "1", + NetworkID: "1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "1", + AccountID: "1", + NetworkID: "1", + }, + }, }, } } @@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) { if properties["routes"] != 2 { t.Errorf("expected 2 routes, got %d", properties["routes"]) } + if properties["networks"] != 1 { + t.Errorf("expected 1 networks, got %d", properties["networks"]) + } + if properties["network_resources"] != 2 { + t.Errorf("expected 2 network_resources, got %d", properties["network_resources"]) + } + if properties["network_routers"] != 1 { + t.Errorf("expected 1 network_routers, got %d", properties["network_routers"]) + } if properties["rules"] != 4 { t.Errorf("expected 4 rules, got %d", properties["rules"]) } From 43ef64cf673fc785a2005f0c7b3f5616211f4065 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:07:21 +0100 Subject: [PATCH 05/10] [client] Ignore case when matching domains in handler chain (#3133) --- client/internal/dns/handler_chain.go | 21 ++- client/internal/dns/handler_chain_test.go | 168 ++++++++++++++++++++++ 2 files changed, 178 insertions(+), 11 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 9302d50b1..5f63d1ab3 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority c.mu.Lock() defer c.mu.Unlock() + pattern = strings.ToLower(dns.Fqdn(pattern)) origPattern := pattern isWildcard := strings.HasPrefix(pattern, "*.") if isWildcard { pattern = pattern[2:] } - pattern = dns.Fqdn(pattern) - origPattern = dns.Fqdn(origPattern) - // First remove any existing handler with same original pattern and priority + // First remove any existing handler with same pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { - if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { + if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { if c.handlers[i].StopHandler != nil { c.handlers[i].StopHandler.stop() } @@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) { pattern = dns.Fqdn(pattern) - // Find and remove handlers matching both original pattern and priority + // Find and remove handlers matching both original pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] - if entry.OrigPattern == pattern && entry.Priority == priority { + if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if entry.StopHandler != nil { entry.StopHandler.stop() } @@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool { c.mu.RLock() defer c.mu.RUnlock() - pattern = dns.Fqdn(pattern) + pattern = strings.ToLower(dns.Fqdn(pattern)) for _, entry := range c.handlers { - if entry.Pattern == pattern { + if strings.EqualFold(entry.Pattern, pattern) { return true } } @@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - qname := r.Question[0].Name + qname := strings.ToLower(r.Question[0].Name) log.Tracef("handling DNS request for domain=%s", qname) c.mu.RLock() @@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants subdomain matching, allow suffix match // Otherwise require exact match if entry.MatchSubdomains { - matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern) } else { - matched = qname == entry.Pattern + matched = strings.EqualFold(qname, entry.Pattern) } } diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 727b6e908..eb40c907f 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + assert.False(t, chain.HasHandlers(testDomain)) } + +func TestHandlerChain_CaseSensitivity(t *testing.T) { + tests := []struct { + name string + scenario string + addHandlers []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + } + query string + expectedCalls int + }{ + { + name: "case insensitive exact match", + scenario: "handler registered lowercase, query uppercase", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "case insensitive wildcard match", + scenario: "handler registered mixed case wildcard, query different case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"*.Example.Com.", nbdns.PriorityDefault, false, true}, + }, + query: "sub.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different case same domain", + scenario: "second handler should replace first despite case difference", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "ExAmPlE.cOm.", + expectedCalls: 1, + }, + { + name: "subdomain matching case insensitive", + scenario: "handler with MatchSubdomains true should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, true, true}, + }, + query: "SUB.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "root zone case insensitive", + scenario: "root zone handler should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {".", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different priority", + scenario: "should call higher priority handler despite case differences", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, + }, + query: "example.com.", + expectedCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlerCalls := make(map[string]bool) // track which patterns were called + + // Add handlers according to test case + for _, h := range tt.addHandlers { + var handler dns.Handler + pattern := h.pattern // capture pattern for closure + + if h.subdomains { + subHandler := &nbdns.MockSubdomainHandler{ + Subdomains: true, + } + if h.shouldMatch { + subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = subHandler + } else { + mockHandler := &nbdns.MockHandler{} + if h.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = mockHandler + } + + chain.AddHandler(pattern, handler, h.priority, nil) + } + + // Execute request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + chain.ServeDNS(&mockResponseWriter{}, r) + + // Verify each handler was called exactly as expected + for _, h := range tt.addHandlers { + wasCalled := handlerCalls[h.pattern] + assert.Equal(t, h.shouldMatch, wasCalled, + "Handler for pattern %q was %s when it should%s have been", + h.pattern, + map[bool]string{true: "called", false: "not called"}[wasCalled], + map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch]) + } + + // Verify total number of calls + assert.Equal(t, tt.expectedCalls, len(handlerCalls), + "Wrong number of total handler calls") + }) + } +} From abbdf20f65f031fd6dc0a95b55a38044b158046b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:08:48 +0100 Subject: [PATCH 06/10] [client] Allow inbound rosenpass port (#3109) --- client/firewall/iptables/manager_linux.go | 2 +- client/internal/dnsfwd/manager.go | 2 +- client/internal/engine.go | 45 +++++++++++++++++++---- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 0e1e5836f..da8e2c08f 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error { } _, err := m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), + net.IP{0, 0, 0, 0}, "all", nil, nil, diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 7cff6d517..f876bda30 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error { IsRange: false, Values: []int{ListenPort}, } - dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err diff --git a/client/internal/engine.go b/client/internal/engine.go index 042d384dc..896104df8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -406,13 +406,9 @@ func (e *Engine) Start() error { e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) - } - - if e.firewall != nil && e.firewall.IsServerRouteSupported() { - err = e.routeManager.EnableServerRouter(e.firewall) - if err != nil { - e.close() - return fmt.Errorf("enable server router: %w", err) + } else if e.firewall != nil { + if err := e.initFirewall(err); err != nil { + return err } } @@ -455,6 +451,41 @@ func (e *Engine) Start() error { return nil } +func (e *Engine) initFirewall(error) error { + if e.firewall.IsServerRouteSupported() { + if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { + e.close() + return fmt.Errorf("enable server router: %w", err) + } + } + + if e.rpManager == nil || !e.config.RosenpassEnabled { + return nil + } + + rosenpassPort := e.rpManager.GetAddress().Port + port := manager.Port{Values: []int{rosenpassPort}} + + // this rule is static and will be torn down on engine down by the firewall manager + if _, err := e.firewall.AddPeerFiltering( + net.IP{0, 0, 0, 0}, + manager.ProtocolUDP, + nil, + &port, + manager.RuleDirectionIN, + manager.ActionAccept, + "", + "", + ); err != nil { + log.Errorf("failed to allow rosenpass interface traffic: %v", err) + return nil + } + + log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort) + + return nil +} + // modifyPeers updates peers that have been modified (e.g. IP address has been changed). // It closes the existing connection, removes it from the peerConns map, and creates a new one. func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { From 2bdb4cb44a8c128214eac0933d040f2e8447a92a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 31 Dec 2024 18:59:37 +0300 Subject: [PATCH 07/10] [management] Preserve jwt groups when accessing API with PAT (#3128) * Skip JWT group sync for token-based authentication Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 6 ++++ management/server/account_test.go | 30 +++++++++++++++++-- .../server/http/middleware/auth_middleware.go | 1 + management/server/jwtclaims/extractor.go | 2 ++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index e60b41b4e..83a8759f9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1252,6 +1252,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { + if claim, exists := claims.Raw[jwtclaims.IsToken]; exists { + if isToken, ok := claim.(bool); ok && isToken { + return nil + } + } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err diff --git a/management/server/account_test.go b/management/server/account_test.go index 280d998fd..2289c96f9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2729,6 +2729,19 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("skip sync for token auth type", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") + }) + t.Run("empty jwt groups", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2822,7 +2835,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.Len(t, user.AutoGroups, 1, "new group should be added") }) - t.Run("remove all JWT groups", func(t *testing.T) { + t.Run("remove all JWT groups when list is empty", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", Raw: jwt.MapClaims{"groups": []interface{}{}}, @@ -2833,7 +2846,20 @@ func TestAccount_SetJWTGroups(t *testing.T) { user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") + assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") + }) + + t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") }) } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0d3459712..0a54cbaed 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[jwtclaims.IsToken] = true jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index c441650e9..18214b434 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -22,6 +22,8 @@ const ( LastLoginSuffix = "nb_last_login" // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation Invited = "nb_invited" + // IsToken claim indicates that auth type from the user is a token + IsToken = "is_token" ) // ExtractClaims Extract function type From 18b049cd2439de3c7769e7666e86ed16254df477 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 31 Dec 2024 18:10:40 +0100 Subject: [PATCH 08/10] [management] remove sorting from network map generation (#3126) --- go.mod | 1 - go.sum | 2 - .../networks/resources/types/resource.go | 1 + management/server/types/account.go | 110 ++++++++---------- management/server/types/account_test.go | 30 ++--- management/server/types/policy.go | 21 ++-- route/route.go | 3 + 7 files changed, 79 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index 330d0763f..d48280df0 100644 --- a/go.mod +++ b/go.mod @@ -79,7 +79,6 @@ require ( github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 - github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index ea4597836..540cbf20b 100644 --- a/go.sum +++ b/go.sum @@ -698,8 +698,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 h1:S5h7yNKStqF8CqFtgtMNMzk/lUI3p82LrX6h2BhlsTM= -github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907/go.mod h1:/7Fy/4/OyrkguTf2i2pO4erUD/8QAlrlmXSdSJPu678= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7eecdce0f..162f90378 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network NetID: route.NetID(n.Name), Description: n.Description, Peer: peer.Key, + PeerID: peer.ID, PeerGroups: nil, Masquerade: router.Masquerade, Metric: router.Metric, diff --git a/management/server/types/account.go b/management/server/types/account.go index 3ef862fa6..f9e1cc9b4 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "github.com/yourbasic/radix" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -304,55 +303,47 @@ func (a *Account) GetPeerNetworkMap( return nm } -func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { - missingPeers := map[string]struct{}{} - for _, r := range networkResourcesRoutes { - if r.Peer == peer.Key { - continue - } +func (a *Account) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peer *nbpeer.Peer, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { - missing := true - for _, p := range slices.Concat(peersToConnect, expiredPeers) { - if r.Peer == p.Key { - missing = false - break - } - } - if missing { - missingPeers[r.Peer] = struct{}{} - } + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} } - if isRouter { - for _, s := range sourcePeers { - if s == peer.ID { - continue - } + delete(sourcePeers, peer.ID) - missing := true - for _, p := range slices.Concat(peersToConnect, expiredPeers) { - if s == p.ID { - missing = false - break - } - } - if missing { - p, ok := a.Peers[s] - if ok { - missingPeers[p.Key] = struct{}{} - } - } + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} } } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } for p := range missingPeers { - for _, p2 := range a.Peers { - if p2.Key == p { - peersToConnect = append(peersToConnect, p2) - break - } + if missingPeer := a.Peers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) } } + return peersToConnect } @@ -1156,7 +1147,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli } func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make([]string, 0) + distPeersWithPolicy := make(map[string]struct{}) for _, id := range rule.Sources { group := a.Groups[id] if group == nil { @@ -1170,16 +1161,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID _, distPeer := distributionPeers[pID] _, valid := validatedPeersMap[pID] if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { - distPeersWithPolicy = append(distPeersWithPolicy, pID) + distPeersWithPolicy[pID] = struct{}{} } } } - radix.Sort(distPeersWithPolicy) - uniqueDistributionPeers := slices.Compact(distPeersWithPolicy) - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(uniqueDistributionPeers)) - for _, pID := range uniqueDistributionPeers { + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { peer := a.Peers[pID] if peer == nil { continue @@ -1306,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { } // GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { var isRoutingPeer bool var routes []*route.Route - allSourcePeers := make([]string, 0) + allSourcePeers := make(map[string]struct{}, len(a.Peers)) for _, resource := range a.NetworkResources { var addSourcePeers bool @@ -1326,7 +1314,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st for _, policy := range resourcePolicies[resource.ID] { peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) if addSourcePeers { - allSourcePeers = append(allSourcePeers, a.getPostureValidPeers(peers, policy.SourcePostureChecks)...) + for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { // add routes for the resource if the peer is in the distribution group for peerId, router := range networkRoutingPeers { @@ -1340,8 +1330,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st } } - radix.Sort(allSourcePeers) - return isRoutingPeer, routes, slices.Compact(allSourcePeers) + return isRoutingPeer, routes, allSourcePeers } func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { @@ -1355,8 +1344,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s } func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { - gObjs := make([]*Group, 0, len(groups)) - tp := 0 + peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { group := a.GetGroup(groupID) if group == nil { @@ -1368,17 +1356,17 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st return group.Peers } - gObjs = append(gObjs, group) - tp += len(group.Peers) + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } } - ids := make([]string, 0, tp) - for _, group := range gObjs { - ids = append(ids, group.Peers...) + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) } - radix.Sort(ids) - return slices.Compact(ids) + return ids } // getNetworkResources filters and returns a list of network resources associated with the given network ID. diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index efe930108..367baef4f 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -316,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) { func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { account := setupTestAccount() - peer := &nbpeer.Peer{Key: "peer1"} + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} networkResourcesRoutes := []*route.Route{ - {Peer: "peer2Key"}, - {Peer: "peer3Key"}, + {Peer: "peer2Key", PeerID: "peer2"}, + {Peer: "peer3Key", PeerID: "peer3"}, } peersToConnect := []*nbpeer.Peer{ - {Key: "peer2Key"}, + {Key: "peer2Key", ID: "peer2"}, } expiredPeers := []*nbpeer.Peer{ - {Key: "peer4Key"}, + {Key: "peer4Key", ID: "peer4"}, } - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 2) require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer3Key", result[1].Key) @@ -345,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { } expiredPeers := []*nbpeer.Peer{} - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 1) require.Equal(t, "peer2Key", result[0].Key) } @@ -364,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { {Key: "peer3Key"}, } - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 1) require.Equal(t, "peer2Key", result[0].Key) } @@ -376,7 +376,7 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { peersToConnect := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{} - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 0) } @@ -559,8 +559,8 @@ func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") - assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -599,7 +599,7 @@ func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -692,8 +692,8 @@ func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") - assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -741,7 +741,7 @@ func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 5b2cf06a0..17964ed1f 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,11 +1,5 @@ package types -import ( - "slices" - - "github.com/yourbasic/radix" -) - const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -126,10 +120,17 @@ func (p *Policy) SourceGroups() []string { if len(p.Rules) == 1 { return p.Rules[0].Sources } - groups := make([]string, 0) + groups := make(map[string]struct{}, len(p.Rules)) for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) + for _, source := range rule.Sources { + groups[source] = struct{}{} + } } - radix.Sort(groups) - return slices.Compact(groups) + + groupIDs := make([]string, 0, len(groups)) + for groupID := range groups { + groupIDs = append(groupIDs, groupID) + } + + return groupIDs } diff --git a/route/route.go b/route/route.go index 8f3c99b4c..ad2aaba89 100644 --- a/route/route.go +++ b/route/route.go @@ -95,6 +95,7 @@ type Route struct { NetID NetID Description string Peer string + PeerID string `gorm:"-"` PeerGroups []string `gorm:"serializer:json"` NetworkType NetworkType Masquerade bool @@ -120,6 +121,7 @@ func (r *Route) Copy() *Route { KeepRoute: r.KeepRoute, NetworkType: r.NetworkType, Peer: r.Peer, + PeerID: r.PeerID, PeerGroups: slices.Clone(r.PeerGroups), Metric: r.Metric, Masquerade: r.Masquerade, @@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool { other.KeepRoute == r.KeepRoute && other.NetworkType == r.NetworkType && other.Peer == r.Peer && + other.PeerID == r.PeerID && other.Metric == r.Metric && other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && From 03fd656344a3e65031bbfe8f5332dbcb522a1708 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 31 Dec 2024 18:45:40 +0100 Subject: [PATCH 09/10] [management] Fix policy tests (#3135) - Add firewall rule isEqual method - Fix tests --- management/server/policy_test.go | 16 +++++++++++----- management/server/types/firewall_rule.go | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 0d17da23a..73fc6edba 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -76,7 +76,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, "GroupWorkstations": { ID: "GroupWorkstations", - Name: "All", + Name: "GroupWorkstations", Peers: []string{ "peerB", "peerA", @@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, } assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) - slices.SortFunc(firewallRules, sortFunc()) - for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + + for _, rule := range firewallRules { + contains := false + for _, expectedRule := range epectedFirewallRules { + if rule.IsEqual(expectedRule) { + contains = true + break + } + } + assert.True(t, contains, "rule not found in expected rules %#v", rule) } }) } diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index 3d1b7e225..4e405152c 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -35,6 +35,15 @@ type FirewallRule struct { Port string } +// IsEqual checks if two firewall rules are equal. +func (r *FirewallRule) IsEqual(other *FirewallRule) bool { + return r.PeerIP == other.PeerIP && + r.Direction == other.Direction && + r.Action == other.Action && + r.Protocol == other.Protocol && + r.Port == other.Port +} + // generateRouteFirewallRules generates a list of firewall rules for a given route. func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { rulesExists := make(map[string]struct{}) From 782e3f8853f1ac455d600fbad16eddbbeebefbac Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 2 Jan 2025 13:51:01 +0100 Subject: [PATCH 10/10] [management] Add integration test for the setup-keys API endpoints (#2936) --- management/cmd/management.go | 4 +- management/server/account.go | 4 +- management/server/account_test.go | 42 - management/server/geolocation/geolocation.go | 39 +- .../server/geolocation/geolocation_test.go | 2 +- management/server/grpcserver.go | 4 +- management/server/http/handler.go | 40 +- .../handlers/policies/geolocations_handler.go | 6 +- .../handlers/policies/policies_handler.go | 2 +- .../policies/posture_checks_handler.go | 6 +- .../policies/posture_checks_handler_test.go | 2 +- .../handlers/setup_keys/setupkeys_handler.go | 8 +- .../setup_keys/setupkeys_handler_test.go | 13 +- .../setupkeys_handler_benchmark_test.go | 226 ++++ .../setupkeys_handler_integration_test.go | 1146 +++++++++++++++++ .../http/testing/testdata/setup_keys.sql | 24 + .../http/testing/testing_tools/tools.go | 307 +++++ management/server/integrated_validator.go | 43 + management/server/jwtclaims/jwtValidator.go | 39 +- management/server/management_test.go | 42 +- management/server/networks/manager.go | 27 + .../server/networks/resources/manager.go | 39 + management/server/setupkey.go | 4 +- 23 files changed, 1919 insertions(+), 150 deletions(-) create mode 100644 management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go create mode 100644 management/server/http/testing/integration/setupkeys_handler_integration_test.go create mode 100644 management/server/http/testing/testdata/setup_keys.sql create mode 100644 management/server/http/testing/testing_tools/tools.go diff --git a/management/cmd/management.go b/management/cmd/management.go index 4f34009b7..1c8fca8dc 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,7 +42,7 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" - httpapi "github.com/netbirdio/netbird/management/server/http" + nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -281,7 +281,7 @@ var ( routersManager := routers.NewManager(store, permissionsManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 83a8759f9..6c8205f26 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -161,7 +161,7 @@ type DefaultAccountManager struct { externalCacheManager ExternalCacheManager ctx context.Context eventStore activity.Store - geo *geolocation.Geolocation + geo geolocation.Geolocation requestBuffer *AccountRequestBuffer @@ -244,7 +244,7 @@ func BuildManager( singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, - geo *geolocation.Geolocation, + geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, diff --git a/management/server/account_test.go b/management/server/account_test.go index 2289c96f9..4f6cdf78d 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -27,7 +27,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -38,47 +37,6 @@ import ( "github.com/netbirdio/netbird/route" ) -type MocIntegratedValidator struct { - ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - if a.ValidatePeerFunc != nil { - return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) - } - return update, false, nil -} -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for _, peer := range peers { - validatedPeers[peer.ID] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) { -} - func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 553a31581..c0179a1c4 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -14,7 +14,14 @@ import ( log "github.com/sirupsen/logrus" ) -type Geolocation struct { +type Geolocation interface { + Lookup(ip net.IP) (*Record, error) + GetAllCountries() ([]Country, error) + GetCitiesByCountry(countryISOCode string) ([]City, error) + Stop() error +} + +type geolocationImpl struct { mmdbPath string mux sync.RWMutex db *maxminddb.Reader @@ -54,7 +61,7 @@ const ( geonamesdbPattern = "geonames_*.db" ) -func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { +func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) { mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) if err != nil { @@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol return nil, err } - geo := &Geolocation{ + geo := &geolocationImpl{ mmdbPath: mmdbPath, mux: sync.RWMutex{}, db: db, @@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) { return db, nil } -func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { +func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() @@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { } // GetAllCountries retrieves a list of all countries. -func (gl *Geolocation) GetAllCountries() ([]Country, error) { +func (gl *geolocationImpl) GetAllCountries() ([]Country, error) { allCountries, err := gl.locationDB.GetAllCountries() if err != nil { return nil, err @@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) { } // GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code. -func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) { +func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) { allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode) if err != nil { return nil, err @@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) return cities, nil } -func (gl *Geolocation) Stop() error { +func (gl *geolocationImpl) Stop() error { close(gl.stopCh) if gl.db != nil { if err := gl.db.Close(); err != nil { @@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin } return nil } + +type Mock struct{} + +func (g *Mock) Lookup(ip net.IP) (*Record, error) { + return &Record{}, nil +} + +func (g *Mock) GetAllCountries() ([]Country, error) { + return []Country{}, nil +} + +func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) { + return []City{}, nil +} + +func (g *Mock) Stop() error { + return nil +} diff --git a/management/server/geolocation/geolocation_test.go b/management/server/geolocation/geolocation_test.go index 9bdefd268..fecd715be 100644 --- a/management/server/geolocation/geolocation_test.go +++ b/management/server/geolocation/geolocation_test.go @@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) { db, err := openDB(filename) assert.NoError(t, err) - geo := &Geolocation{ + geo := &geolocationImpl{ mux: sync.RWMutex{}, db: db, stopCh: make(chan struct{}), diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2635ac11b..daa23d2ab 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -38,7 +38,7 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config secretsManager SecretsManager - jwtValidator *jwtclaims.JWTValidator + jwtValidator jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager @@ -61,7 +61,7 @@ func NewServer( return nil, err } - var jwtValidator *jwtclaims.JWTValidator + var jwtValidator jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7db7ab5b8..cc2ad00b7 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -35,15 +35,8 @@ import ( const apiPrefix = "/api" -type apiHandler struct { - Router *mux.Router - AccountManager s.AccountManager - geolocationManager *geolocation.Geolocation - AuthCfg configs.AuthCfg -} - -// APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. +func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa router := rootRouter.PathPrefix(prefix).Subrouter() router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) - api := apiHandler{ - Router: router, - AccountManager: accountManager, - geolocationManager: LocationManager, - AuthCfg: authCfg, - } - - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } - accounts.AddEndpoints(api.AccountManager, authCfg, router) - peers.AddEndpoints(api.AccountManager, authCfg, router) - users.AddEndpoints(api.AccountManager, authCfg, router) - setup_keys.AddEndpoints(api.AccountManager, authCfg, router) - policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) - groups.AddEndpoints(api.AccountManager, authCfg, router) - routes.AddEndpoints(api.AccountManager, authCfg, router) - dns.AddEndpoints(api.AccountManager, authCfg, router) - events.AddEndpoints(api.AccountManager, authCfg, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) + accounts.AddEndpoints(accountManager, authCfg, router) + peers.AddEndpoints(accountManager, authCfg, router) + users.AddEndpoints(accountManager, authCfg, router) + setup_keys.AddEndpoints(accountManager, authCfg, router) + policies.AddEndpoints(accountManager, LocationManager, authCfg, router) + groups.AddEndpoints(accountManager, authCfg, router) + routes.AddEndpoints(accountManager, authCfg, router) + dns.AddEndpoints(accountManager, authCfg, router) + events.AddEndpoints(accountManager, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index e5bf3e695..161d97402 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -22,18 +22,18 @@ var ( // geolocationsHandler is a handler that returns locations. type geolocationsHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler -func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index b1035c570..a748e73b8 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -23,7 +23,7 @@ type handler struct { claimsExtractor *jwtclaims.ClaimsExtractor } -func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { policiesHandler := newHandler(accountManager, authCfg) router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 44917605b..ce0d4878c 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -19,11 +19,11 @@ import ( // postureChecksHandler is a handler that returns posture checks of the account. type postureChecksHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") @@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag } // newPostureChecksHandler creates a new PostureChecks handler -func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index e9a539e45..237687fd4 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH return claims.AccountId, claims.UserId, nil }, }, - geolocationManager: &geolocation.Geolocation{}, + geolocationManager: &geolocation.Mock{}, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 89696a165..a627d7203 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { return } - apiSetupKeys := toResponseBody(setupKey) + apiSetupKeys := ToResponseBody(setupKey) // for the creation we need to send the plain key apiSetupKeys.Key = setupKey.Key @@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { apiSetupKeys := make([]*api.SetupKey, 0) for _, key := range setupKeys { - apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) + apiSetupKeys = append(apiSetupKeys, ToResponseBody(key)) } util.WriteJSONObject(r.Context(), w, apiSetupKeys) @@ -216,14 +216,14 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) - err := json.NewEncoder(w).Encode(toResponseBody(key)) + err := json.NewEncoder(w).Encode(ToResponseBody(key)) if err != nil { util.WriteError(ctx, err, w) return } } -func toResponseBody(key *types.SetupKey) *api.SetupKey { +func ToResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 4ecb1e9ed..f56227c10 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -26,7 +26,6 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" - testAccountID = "test_id" ) func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, @@ -81,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe return jwtclaims.AuthorizationClaims{ UserId: user.Id, Domain: "hotmail.com", - AccountId: testAccountID, + AccountId: "testAccountId", } }), ), @@ -102,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) { updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true - expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey := ToResponseBody(newSetupKey) expectedNewKey.Key = plainKey tt := []struct { name string @@ -120,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys", expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)}, + expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)}, }, { name: "Get Existing Setup Key", @@ -128,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys/" + existingSetupKeyID, expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(defaultSetupKey), + expectedSetupKey: ToResponseBody(defaultSetupKey), }, { name: "Get Not Existing Setup Key", @@ -159,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) { ))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(updatedDefaultSetupKey), + expectedSetupKey: ToResponseBody(updatedDefaultSetupKey), }, { name: "Delete Setup Key", @@ -228,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) { func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) { t.Helper() // this comparison is done manually because when converting to JSON dates formatted differently - // assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work + // assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "") assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "") assert.Equal(t, got.Name, expected.Name) diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go new file mode 100644 index 000000000..5e2895bcc --- /dev/null +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -0,0 +1,226 @@ +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ + "Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5}, + "Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100}, + "Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000}, + "Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000}, + "Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000}, + "Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000}, + "Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000}, + "Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000}, +} + +func BenchmarkCreateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + requestBody := api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkUpdateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + groupId := testing_tools.TestGroupId + if i%2 == 0 { + groupId = testing_tools.NewGroupId + } + requestBody := api.SetupKeyRequest{ + AutoGroups: []string{groupId}, + Revoked: false, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOneSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllSetupKeys(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15}, + "Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40}, + "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeleteSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go new file mode 100644 index 000000000..193c0fb02 --- /dev/null +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -0,0 +1,1146 @@ +package integration + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +func Test_SetupKeys_Create(t *testing.T) { + truePointer := true + + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.CreateSetupKeyRequest + requestType string + requestPath string + userId string + }{ + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with already existing name", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.ExistingKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key as on-off with more than one usage", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 3, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with expiration in the past", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: -testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key with AutoGroups that do exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key for ephemeral peers", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + Ephemeral: &truePointer, + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with AutoGroups that do not exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{"someGroupID"}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + req := testing_tools.BuildRequest(t, body, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.SetupKeyRequest + requestType string + requestPath string + requestId string + }{ + { + name: "Add existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Add non-existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, "someGroupId"}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Add existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Remove existing Group from existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Remove existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someID", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Revoke existing valid Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Un-Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Revoke existing expired Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "expired", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Get(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Get existing valid Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Get existing expired Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Get existing revoked Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Get non-existing Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectRespnose := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectRespnose { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse []*api.SetupKey + requestType string + requestPath string + }{ + { + name: "Get all Setup Keys", + requestType: http.MethodGet, + requestPath: "/api/setup-keys", + expectedStatus: http.StatusOK, + expectedResponse: []*api.SetupKey{ + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := []api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + sort.Slice(got, func(i, j int) bool { + return got[i].UsageLimit < got[j].UsageLimit + }) + + sort.Slice(tc.expectedResponse, func(i, j int) bool { + return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit + }) + + for i, _ := range tc.expectedResponse { + validateCreatedKey(t, tc.expectedResponse[i], &got[i]) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key)) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Delete existing valid Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Delete existing expired Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Delete existing revoked Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Delete non-existing Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + _, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + assert.Errorf(t, err, "Expected error when trying to get deleted key") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) { + t.Helper() + + if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) || + got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) || + got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) { + got.Expires = time.Time{} + expectedKey.Expires = time.Time{} + } + + if got.Id == "" { + t.Fatalf("Expected key to have an ID") + } + got.Id = "" + + if got.Key == "" { + t.Fatalf("Expected key to have a key") + } + got.Key = "" + + if got.UpdatedAt.After(time.Now().Add(-1*time.Minute)) && got.UpdatedAt.Before(time.Now().Add(+1*time.Minute)) { + got.UpdatedAt = time.Time{} + expectedKey.UpdatedAt = time.Time{} + } + + expectedKey.UpdatedAt = expectedKey.UpdatedAt.In(time.UTC) + got.UpdatedAt = got.UpdatedAt.In(time.UTC) + + assert.Equal(t, expectedKey, got) +} diff --git a/management/server/http/testing/testdata/setup_keys.sql b/management/server/http/testing/testdata/setup_keys.sql new file mode 100644 index 000000000..a315ea0f7 --- /dev/null +++ b/management/server/http/testing/testdata/setup_keys.sql @@ -0,0 +1,24 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); + + +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); + diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go new file mode 100644 index 000000000..da910c5c3 --- /dev/null +++ b/management/server/http/testing/testing_tools/tools.go @@ -0,0 +1,307 @@ +package testing_tools + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + TestAccountId = "testAccountId" + TestPeerId = "testPeerId" + TestGroupId = "testGroupId" + TestKeyId = "testKeyId" + + TestUserId = "testUserId" + TestAdminId = "testAdminId" + TestOwnerId = "testOwnerId" + TestServiceUserId = "testServiceUserId" + TestServiceAdminId = "testServiceAdminId" + BlockedUserId = "blockedUserId" + OtherUserId = "otherUserId" + InvalidToken = "invalidToken" + + NewKeyName = "newKey" + NewGroupId = "newGroupId" + ExpiresIn = 3600 + RevokedKeyId = "revokedKeyId" + ExpiredKeyId = "expiredKeyId" + + ExistingKeyName = "existingKey" +) + +type TB interface { + Cleanup(func()) + Helper() + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + TempDir() string +} + +// BenchmarkCase defines a single benchmark test case +type BenchmarkCase struct { + Peers int + Groups int + Users int + SetupKeys int +} + +// PerformanceMetrics holds the performance expectations +type PerformanceMetrics struct { + MinMsPerOpLocal float64 + MaxMsPerOpLocal float64 + MinMsPerOpCICD float64 + MaxMsPerOpCICD float64 +} + +func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) + done := make(chan struct{}) + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + + geoMock := &geolocation.Mock{} + validatorMock := server.MocIntegratedValidator{} + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { + t.Helper() + + req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody)) + req.Header.Set("Authorization", "Bearer "+user) + + return req +} + +func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) { + t.Helper() + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if !expectResponse { + return nil, false + } + + if status := recorder.Code; status != expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v, content: %s", + status, expectedStatus, string(content)) + } + + return content, expectedStatus == http.StatusOK +} + +func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { + b.Helper() + + ctx := context.Background() + account, err := am.GetAccount(ctx, TestAccountId) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + + // Create peers + for i := 0; i < peers; i++ { + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("oldpeer-%d", i), + DNSLabel: fmt.Sprintf("oldpeer-%d", i), + Key: peerKey.PublicKey().String(), + IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), + Status: &nbpeer.PeerStatus{}, + UserID: TestUserId, + } + account.Peers[peer.ID] = peer + } + + // Create users + for i := 0; i < users; i++ { + user := &types.User{ + Id: fmt.Sprintf("olduser-%d", i), + AccountID: account.Id, + Role: types.UserRoleUser, + } + account.Users[user.Id] = user + } + + for i := 0; i < setupKeys; i++ { + key := &types.SetupKey{ + Id: fmt.Sprintf("oldkey-%d", i), + AccountID: account.Id, + AutoGroups: []string{"someGroupID"}, + ExpiresAt: time.Now().Add(ExpiresIn * time.Second), + Name: NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + account.SetupKeys[key.Id] = key + } + + // Create groups and policies + account.Policies = make([]*types.Policy, 0, groups) + for i := 0; i < groups; i++ { + groupID := fmt.Sprintf("group-%d", i) + group := &types.Group{ + ID: groupID, + Name: fmt.Sprintf("Group %d", i), + } + for j := 0; j < peers/groups; j++ { + peerIndex := i*(peers/groups) + j + group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) + } + account.Groups[groupID] = group + + // Create a policy for this group + policy := &types.Policy{ + ID: fmt.Sprintf("policy-%d", i), + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: fmt.Sprintf("rule-%d", i), + Name: fmt.Sprintf("Rule for Group %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{groupID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, policy) + } + + account.PostureChecks = []*posture.Checks{ + { + ID: "PostureChecksAll", + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + } + + err = am.Store.SaveAccount(context.Background(), account) + if err != nil { + b.Fatalf("Failed to save account: %v", err) + } + +} + +func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) { + b.Helper() + + if recorder.Code != http.StatusOK { + b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code) + } + + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + minExpected := perfMetrics.MinMsPerOpLocal + maxExpected := perfMetrics.MaxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = perfMetrics.MinMsPerOpCICD + maxExpected = perfMetrics.MaxMsPerOpCICD + } + + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected) + } +} diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 47c4ca6ae..62e9213f7 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" ) @@ -78,3 +79,45 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) } + +type MocIntegratedValidator struct { + ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) +} + +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { + if a.ValidatePeerFunc != nil { + return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) + } + return update, false, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { + return false, false, nil +} + +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + // just a dummy +} + +func (MocIntegratedValidator) Stop(_ context.Context) { + // just a dummy +} diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index b91616fa5..79e59e76f 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -72,15 +72,19 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// JWTValidator struct to handle token validation and parsing -type JWTValidator struct { +type JWTValidator interface { + ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) +} + +// jwtValidatorImpl struct to handle token validation and parsing +type jwtValidatorImpl struct { options Options } var keyNotFound = errors.New("unable to find appropriate key") // NewJWTValidator constructor -func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { +func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) { keys, err := getPemKeys(ctx, keysLocation) if err != nil { return nil, err @@ -146,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, options.UserProperty = "user" } - return &JWTValidator{ + return &jwtValidatorImpl{ options: options, }, nil } // ValidateAndParse validates the token and returns the parsed token -func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { +func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // Check if it was required @@ -318,3 +322,28 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } + +type JwtValidatorMock struct{} + +func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { + claimMaps := jwt.MapClaims{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + claimMaps[UserIDClaim] = token + claimMaps[AccountIDSuffix] = "testAccountId" + claimMaps[DomainIDSuffix] = "test.com" + claimMaps[DomainCategorySuffix] = "private" + case "otherUserId": + claimMaps[UserIDClaim] = "otherUserId" + claimMaps[AccountIDSuffix] = "otherAccountId" + claimMaps[DomainIDSuffix] = "other.com" + claimMaps[DomainCategorySuffix] = "private" + case "invalidToken": + return nil, errors.New("invalid token") + } + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + return jwtToken, nil +} + diff --git a/management/server/management_test.go b/management/server/management_test.go index 40514ae14..cfa2c138f 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -21,13 +21,10 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -448,43 +445,6 @@ var _ = Describe("Management service", func() { }) }) -type MocIntegratedValidator struct { -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - return update, false, nil -} - -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for p := range peers { - validatedPeers[p] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) {} - func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -547,7 +507,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. log.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 4a7b3db77..51205f1e9 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -32,6 +32,9 @@ type managerImpl struct { routersManager routers.Manager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -185,3 +188,27 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return nil } + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + return []*types.Network{}, nil +} + +func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + return &types.Network{}, nil +} + +func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + return nil +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 0fff5bcf8..02b462947 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -34,6 +34,9 @@ type managerImpl struct { accountManager s.AccountManager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -381,3 +384,39 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti return eventsToStore, nil } + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + return map[string][]string{}, nil +} + +func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + return nil +} + +func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + return []func(){}, nil +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 9a4a1efb8..f2f1aad45 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -75,7 +75,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) @@ -132,7 +132,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)