diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 55043c5b2..2db9ea642 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -169,6 +169,10 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p log.Debugf("registering handler %s with priority %d", handler, priority) for _, domain := range domains { + if domain == "" { + log.Warn("skipping empty domain") + continue + } s.handlerChain.AddHandler(domain, handler, priority, nil) s.handlerPriorities[domain] = priority s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) @@ -188,6 +192,10 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) { // Only deregister from service if no handlers remain if !s.handlerChain.HasHandlers(domain) { + if domain == "" { + log.Warn("skipping empty domain") + continue + } s.service.DeregisterMux(nbdns.NormalizeZone(domain)) } } diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index f886a54d2..f1eef0055 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -27,7 +27,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSFor return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, - domains: domains, + domains: filterDomains(domains), } } @@ -35,10 +35,6 @@ func (f *DNSForwarder) Listen() error { log.Infof("listen DNS forwarder on address=%s", f.listenAddress) mux := dns.NewServeMux() - for _, d := range f.domains { - mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) - } - dnsServer := &dns.Server{ Addr: f.listenAddress, Net: "udp", @@ -54,10 +50,11 @@ func (f *DNSForwarder) UpdateDomains(domains []string) { f.mux.HandleRemove(d) } - for _, d := range f.domains { - f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) + newDomains := filterDomains(domains) + for _, d := range newDomains { + f.mux.HandleFunc(d, f.handleDNSQuery) } - f.domains = domains + f.domains = newDomains } func (f *DNSForwarder) Close(ctx context.Context) error { @@ -141,3 +138,16 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { log.Errorf("failed to write DNS response: %v", err) } } + +// filterDomains returns a list of normalized domains +func filterDomains(domains []string) []string { + newDomains := make([]string, 0, len(domains)) + for _, d := range domains { + if d == "" { + log.Warn("empty domain in DNS forwarder") + continue + } + newDomains = append(newDomains, nbdns.NormalizeZone(d)) + } + return newDomains +} diff --git a/client/internal/engine.go b/client/internal/engine.go index b6fae5b2b..9724e2a22 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -808,12 +808,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } routedDomains, routes := toRoutes(networkMap.GetRoutes()) + e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) + if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) - log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 28bf20d5f..10cb03f1d 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -67,7 +67,6 @@ func (d *DnsInterceptor) AddRoute(context.Context) error { func (d *DnsInterceptor) RemoveRoute() error { d.mu.Lock() - defer d.mu.Unlock() var merr *multierror.Error for domain, prefixes := range d.interceptedDomains { @@ -89,6 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error { } clear(d.interceptedDomains) + d.mu.Unlock() d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) diff --git a/management/cmd/management.go b/management/cmd/management.go index 3eb52eb90..4f34009b7 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -276,10 +276,10 @@ var ( userManager := users.NewManager(store) settingsManager := settings.NewManager(store) permissionsManager := permissions.NewManager(userManager, settingsManager) - resourcesManager := resources.NewManager(store, permissionsManager, accountManager) + groupsManager := groups.NewManager(store, permissionsManager, accountManager) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) routersManager := routers.NewManager(store, permissionsManager, accountManager) - networksManager := networks.NewManager(store, permissionsManager, resourcesManager) - groupsManager := groups.NewManager(store, permissionsManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 2165eba9c..5379a8dd8 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -154,6 +154,21 @@ const ( AccountRoutingPeerDNSResolutionEnabled Activity = 71 AccountRoutingPeerDNSResolutionDisabled Activity = 72 + + NetworkCreated Activity = 73 + NetworkUpdated Activity = 74 + NetworkDeleted Activity = 75 + + NetworkResourceCreated Activity = 76 + NetworkResourceUpdated Activity = 77 + NetworkResourceDeleted Activity = 78 + + NetworkRouterCreated Activity = 79 + NetworkRouterUpdated Activity = 80 + NetworkRouterDeleted Activity = 81 + + ResourceAddedToGroup Activity = 82 + ResourceRemovedFromGroup Activity = 83 ) var activityMap = map[Activity]Code{ @@ -234,6 +249,21 @@ var activityMap = map[Activity]Code{ AccountRoutingPeerDNSResolutionEnabled: {"Account routing peer DNS resolution enabled", "account.setting.routing.peer.dns.resolution.enable"}, AccountRoutingPeerDNSResolutionDisabled: {"Account routing peer DNS resolution disabled", "account.setting.routing.peer.dns.resolution.disable"}, + + NetworkCreated: {"Network created", "network.create"}, + NetworkUpdated: {"Network updated", "network.update"}, + NetworkDeleted: {"Network deleted", "network.delete"}, + + NetworkResourceCreated: {"Network resource created", "network.resource.create"}, + NetworkResourceUpdated: {"Network resource updated", "network.resource.update"}, + NetworkResourceDeleted: {"Network resource deleted", "network.resource.delete"}, + + NetworkRouterCreated: {"Network router created", "network.router.create"}, + NetworkRouterUpdated: {"Network router updated", "network.router.update"}, + NetworkRouterDeleted: {"Network router deleted", "network.router.delete"}, + + ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, + ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index 905277064..1162348bd 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" @@ -12,18 +14,26 @@ import ( type Manager interface { GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error + AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID string, resourceID *types.Resource) (func(), error) + RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID, resourceID string) (func(), error) } type managerImpl struct { store store.Store permissionsManager permissions.Manager + accountManager s.AccountManager } -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, permissionsManager: permissionsManager, + accountManager: accountManager, } } @@ -58,7 +68,44 @@ func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, return err } - return m.store.AddResourceToGroup(ctx, accountID, groupID, resource) + event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, groupID, resource) + if err != nil { + return fmt.Errorf("error adding resource to group: %w", err) + } + + event() + + return nil +} + +func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID string, resource *types.Resource) (func(), error) { + err := transaction.AddResourceToGroup(ctx, accountID, groupID, resource) + if err != nil { + return nil, fmt.Errorf("error adding resource to group: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, accountID, groupID, accountID, activity.ResourceAddedToGroup, nil) + } + + return event, nil +} + +func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID, resourceID string) (func(), error) { + err := transaction.RemoveResourceFromGroup(ctx, accountID, groupID, resourceID) + if err != nil { + return nil, fmt.Errorf("error removing resource from group: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, accountID, groupID, accountID, activity.ResourceRemovedFromGroup, nil) + } + + return event, nil +} + +func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) } func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { @@ -97,3 +144,31 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum } return groupsInfo } + +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + return map[string]*types.Group{}, nil +} + +func (m *mockManager) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error { + return nil +} + +func (m *mockManager) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID string, resourceID *types.Resource) (func(), error) { + return func() { + // noop + }, nil +} + +func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID, resourceID string) (func(), error) { + return func() { + // noop + }, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 932d4bfdd..351976baf 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -1255,11 +1255,18 @@ components: items: type: string example: ch8i4ug6lnn4g9hqv7m1 + policies: + description: List of policy IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m2 required: - id - routers - resources - routing_peers_count + - policies - $ref: '#/components/schemas/NetworkRequest' NetworkResourceMinimum: type: object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 80aff514a..40574d6f1 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -530,6 +530,9 @@ type Network struct { // Name Network name Name string `json:"name"` + // Policies List of policy IDs associated with the network + Policies []string `json:"policies"` + // Resources List of network resource IDs associated with the network Resources []string `json:"resources"` @@ -551,7 +554,7 @@ type NetworkRequest struct { // NetworkResource defines model for NetworkResource. type NetworkResource struct { - // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) Address string `json:"address"` // Description Network resource description @@ -572,7 +575,7 @@ type NetworkResource struct { // NetworkResourceMinimum defines model for NetworkResourceMinimum. type NetworkResourceMinimum struct { - // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) Address string `json:"address"` // Description Network resource description @@ -584,7 +587,7 @@ type NetworkResourceMinimum struct { // NetworkResourceRequest defines model for NetworkResourceRequest. type NetworkResourceRequest struct { - // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) Address string `json:"address"` // Description Network resource description diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3a169da9d..7db7ab5b8 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -98,7 +98,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa routes.AddEndpoints(api.AccountManager, authCfg, router) dns.AddEndpoints(api.AccountManager, authCfg, router) events.AddEndpoints(api.AccountManager, authCfg, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index 9875b139c..6b36a8fce 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -8,6 +8,7 @@ import ( "github.com/gorilla/mux" + s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" @@ -27,17 +28,18 @@ type handler struct { networksManager networks.Manager resourceManager resources.Manager routerManager routers.Manager + accountManager s.AccountManager groupsManager groups.Manager extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) claimsExtractor *jwtclaims.ClaimsExtractor } -func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { addRouterEndpoints(routerManager, extractFromToken, authCfg, router) addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) - networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, extractFromToken, authCfg) + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg) router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") @@ -45,12 +47,13 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") } -func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { return &handler{ networksManager: networksManager, resourceManager: resourceManager, routerManager: routerManager, groupsManager: groupsManager, + accountManager: accountManager, extractFromToken: extractFromToken, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -91,7 +94,13 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups)) + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) } func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { @@ -119,7 +128,15 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0)) + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) } func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { @@ -149,7 +166,15 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount)) + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) } func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { @@ -191,7 +216,15 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount)) + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) } func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { @@ -256,11 +289,12 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne return routerIDs, resourceIDs, peerCounter, nil } -func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group) []*api.Network { +func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group, account *nbtypes.Account) []*api.Network { var networkResponse []*api.Network for _, network := range networks { routerIDs, peerCounter := getRouterIDs(network, routers, groups) - networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter)) + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter, policyIDs)) } return networkResponse } diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index 53fbb7b93..a0dc9a10d 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -14,7 +14,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" - nbtypes "github.com/netbirdio/netbird/management/server/types" ) type resourceHandler struct { @@ -130,18 +129,6 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) return } - res := nbtypes.Resource{ - ID: resource.ID, - Type: resource.Type.String(), - } - for _, groupID := range req.Groups { - err = h.groupsManager.AddResourceToGroup(r.Context(), accountID, userID, groupID, &res) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) @@ -205,18 +192,6 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) return } - res := nbtypes.Resource{ - ID: resource.ID, - Type: resource.Type.String(), - } - for _, groupID := range req.Groups { - err = h.groupsManager.AddResourceToGroup(r.Context(), accountID, userID, groupID, &res) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index d5291d9da..cc7b546a8 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -6,7 +6,10 @@ import ( "github.com/rs/xid" + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" @@ -23,15 +26,19 @@ type Manager interface { type managerImpl struct { store store.Store + accountManager s.AccountManager permissionsManager permissions.Manager resourcesManager resources.Manager + routersManager routers.Manager } -func NewManager(store store.Store, permissionsManager permissions.Manager, manager resources.Manager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, permissionsManager: permissionsManager, - resourcesManager: manager, + resourcesManager: resourceManager, + routersManager: routersManager, + accountManager: accountManager, } } @@ -58,7 +65,14 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network network.ID = xid.New().String() - return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) + err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) + if err != nil { + return nil, fmt.Errorf("failed to save network: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta()) + + return network, nil } func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { @@ -82,6 +96,13 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network return nil, status.NewPermissionDeniedError() } + _, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) + return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) } @@ -94,20 +115,24 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return status.NewPermissionDeniedError() } - unlock := m.store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + network, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } - return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) if err != nil { return fmt.Errorf("failed to get resources in network: %w", err) } for _, resource := range resources { - err = m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resource.ID) + event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resource.ID) if err != nil { return fmt.Errorf("failed to delete resource: %w", err) } + eventsToStore = append(eventsToStore, event...) } routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) @@ -116,12 +141,33 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw } for _, router := range routers { - err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, router.ID) + event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, networkID, router.ID) if err != nil { return fmt.Errorf("failed to delete router: %w", err) } + eventsToStore = append(eventsToStore, event) } - return transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) + }) + + return nil }) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil } diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go index af1ce1cae..edd830c25 100644 --- a/management/server/networks/manager_test.go +++ b/management/server/networks/manager_test.go @@ -6,8 +6,10 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/store" @@ -25,8 +27,10 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) networks, err := manager.GetAllNetworks(ctx, accountID, userID) require.NoError(t, err) @@ -46,8 +50,10 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) networks, err := manager.GetAllNetworks(ctx, accountID, userID) require.Error(t, err) @@ -67,8 +73,10 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -88,8 +96,10 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) network, err := manager.GetNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -111,8 +121,10 @@ func Test_CreateNetworkSuccessfully(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) createdNetwork, err := manager.CreateNetwork(ctx, userID, network) require.NoError(t, err) @@ -134,8 +146,10 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) createdNetwork, err := manager.CreateNetwork(ctx, userID, network) require.Error(t, err) @@ -155,8 +169,10 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) err = manager.DeleteNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -175,8 +191,10 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) err = manager.DeleteNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -198,8 +216,10 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) { t.Cleanup(cleanUp) am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) require.NoError(t, err) @@ -223,8 +243,10 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} permissionsManager := permissions.NewManagerMock() - resourcesManager := resources.NewManager(s, permissionsManager, &am) - manager := NewManager(s, permissionsManager, resourcesManager) + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) require.Error(t, err) diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 044c2dbb7..bc27d6c2f 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -6,10 +6,14 @@ import ( "fmt" s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) type Manager interface { @@ -20,19 +24,21 @@ type Manager interface { GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error - DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error + DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) ([]func(), error) } type managerImpl struct { store store.Store permissionsManager permissions.Manager + groupsManager groups.Manager accountManager s.AccountManager } -func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, permissionsManager: permissionsManager, + groupsManager: groupsManager, accountManager: accountManager, } } @@ -92,21 +98,55 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return nil, status.NewPermissionDeniedError() } - resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address) + resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs) if err != nil { return nil, fmt.Errorf("failed to create new network resource: %w", err) } - _, err = m.store.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) - if err == nil { - return nil, errors.New("resource already exists") - } - - err = m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil { + return errors.New("resource already exists") + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network.Name)) + } + eventsToStore = append(eventsToStore, event) + + res := nbtypes.Resource{ + ID: resource.ID, + Type: resource.Type.String(), + } + for _, groupID := range resource.GroupIDs { + event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, groupID, &res) + if err != nil { + return fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + return nil + }) if err != nil { return nil, fmt.Errorf("failed to create network resource: %w", err) } + for _, event := range eventsToStore { + event() + } + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -151,17 +191,50 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc resource.Domain = domain resource.Prefix = prefix - _, err = m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) - if err != nil { - return nil, fmt.Errorf("failed to get network resource: %w", err) - } + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } - oldResource, err := m.store.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) - if err == nil && oldResource.ID != resource.ID { - return nil, errors.New("new resource name already exists") - } + if network.ID != resource.NetworkID { + return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) + } + + _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil && oldResource.ID != resource.ID { + return errors.New("new resource name already exists") + } + + oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + events, err := m.updateResourceGroups(ctx, transaction, resource, oldResource) + if err != nil { + return fmt.Errorf("failed to update resource groups: %w", err) + } + + eventsToStore = append(eventsToStore, events...) + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network.Name)) + }) + + return nil + }) - err = m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) if err != nil { return nil, fmt.Errorf("failed to update network resource: %w", err) } @@ -171,6 +244,44 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc return resource, nil } +func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, newResource, oldResource *types.NetworkResource) ([]func(), error) { + res := nbtypes.Resource{ + ID: newResource.ID, + Type: newResource.Type.String(), + } + + oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + oldGroupsIds := make([]string, 0) + for _, group := range oldResourceGroups { + oldGroupsIds = append(oldGroupsIds, group.ID) + } + + var eventsToStore []func() + groupsToAdd := util.Difference(newResource.GroupIDs, oldGroupsIds) + for _, groupID := range groupsToAdd { + events, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, newResource.AccountID, groupID, &res) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + groupsToRemove := util.Difference(oldGroupsIds, newResource.GroupIDs) + for _, groupID := range groupsToRemove { + events, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, newResource.AccountID, groupID, res.ID) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + return eventsToStore, nil +} + func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) if err != nil { @@ -183,43 +294,68 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net unlock := m.store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + var events []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - return m.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resourceID) + events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resourceID) + return err + }) - if err != nil { + if err != nil { return fmt.Errorf("failed to delete network resource: %w", err) } - - go m.accountManager.UpdateAccountPeers(ctx, accountID) - - return nil + + for _, event := range events { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil } -func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error { +func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) ([]func(), error) { resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID) if err != nil { - return fmt.Errorf("failed to get network resource: %w", err) + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) } if resource.NetworkID != networkID { - return errors.New("resource not part of network") + return nil, errors.New("resource not part of network") } - account, err := transaction.GetAccount(ctx, accountID) + groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID) if err != nil { - return fmt.Errorf("failed to get account: %w", err) + return nil, fmt.Errorf("failed to get resource groups: %w", err) } - account.DeleteResource(resource.ID) - err = transaction.SaveAccount(ctx, account) - if err != nil { - return fmt.Errorf("failed to save account: %w", err) + var eventsToStore []func() + + for _, group := range groups { + event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, group.ID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to remove resource from group: %w", err) + } + eventsToStore = append(eventsToStore, event) } err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) + return nil, fmt.Errorf("failed to increment network serial: %w", err) } - return transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to delete network resource: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, accountID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network.Name)) + }) + + return eventsToStore, nil } diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index e9ce8d280..993cd65df 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -2,11 +2,11 @@ package resources import ( "context" - "errors" "testing" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" @@ -20,14 +20,15 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { userID := "allowedUser" networkID := "testNetworkId" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -40,14 +41,15 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { userID := "invalidUser" networkID := "testNetworkId" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -59,14 +61,15 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { accountID := "testAccountId" userID := "allowedUser" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.NoError(t, err) @@ -78,14 +81,15 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { accountID := "testAccountId" userID := "invalidUser" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.Error(t, err) @@ -100,14 +104,15 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) { networkID := "testNetworkId" resourceID := "testResourceId" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -121,14 +126,15 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { networkID := "testNetworkId" resourceID := "testResourceId" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) @@ -154,7 +160,8 @@ func Test_CreateResourceSuccessfully(t *testing.T) { t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(store, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.NoError(t, err) @@ -179,7 +186,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(store, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -205,7 +213,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(store, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -230,7 +239,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(store, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -252,14 +262,15 @@ func Test_UpdateResourceSuccessfully(t *testing.T) { Address: "1.2.3.0/24", } - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.NoError(t, err) @@ -283,14 +294,15 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { Address: "1.2.3.0/24", } - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -312,18 +324,18 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { Address: "1.2.3.0/24", } - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) - require.Equal(t, errors.New("new resource name already exists"), err) require.Nil(t, updatedResource) } @@ -341,14 +353,15 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { Address: "1.2.3.0/24", } - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -362,14 +375,15 @@ func Test_DeleteResourceSuccessfully(t *testing.T) { networkID := "testNetworkId" resourceID := "testResourceId" - s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) if err != nil { t.Fatal(err) } t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(s, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -389,7 +403,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { t.Cleanup(cleanUp) permissionsManager := permissions.NewManagerMock() am := mock_server.MockAccountManager{} - manager := NewManager(store, permissionsManager, &am) + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index a1413dfd2..c8c19c951 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -6,11 +6,12 @@ import ( "net/netip" "regexp" + "github.com/rs/xid" + nbDomain "github.com/netbirdio/netbird/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" - "github.com/rs/xid" "github.com/netbirdio/netbird/management/server/http/api" ) @@ -34,12 +35,13 @@ type NetworkResource struct { Name string Description string Type NetworkResourceType - Address string `gorm:"-"` + Address string `gorm:"-"` + GroupIDs []string `gorm:"-"` Domain string Prefix netip.Prefix `gorm:"serializer:json"` } -func NewNetworkResource(accountID, networkID, name, description, address string) (*NetworkResource, error) { +func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string) (*NetworkResource, error) { resourceType, domain, prefix, err := GetResourceType(address) if err != nil { return nil, fmt.Errorf("invalid address: %w", err) @@ -55,6 +57,7 @@ func NewNetworkResource(accountID, networkID, name, description, address string) Address: address, Domain: domain, Prefix: prefix, + GroupIDs: groupIDs, }, nil } @@ -81,6 +84,7 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) { n.Description = *req.Description } n.Address = req.Address + n.GroupIDs = req.Groups } func (n *NetworkResource) Copy() *NetworkResource { @@ -94,6 +98,7 @@ func (n *NetworkResource) Copy() *NetworkResource { Address: n.Address, Domain: n.Domain, Prefix: n.Prefix, + GroupIDs: n.GroupIDs, } } @@ -137,6 +142,10 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network return r } +func (n *NetworkResource) EventMeta(networkName string) map[string]any { + return map[string]any{"name": n.Name, "type": n.Type, "network_name": networkName} +} + // GetResourceType returns the type of the resource based on the address func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) { if prefix, err := netip.ParsePrefix(address); err == nil { diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 2103beb06..dc092d8ec 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -8,7 +8,9 @@ import ( "github.com/rs/xid" s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" @@ -21,6 +23,7 @@ type Manager interface { GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error + DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, routerID string) (func(), error) } type managerImpl struct { @@ -29,6 +32,9 @@ type managerImpl struct { accountManager s.AccountManager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -80,13 +86,32 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } - router.ID = xid.New().String() + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } - err = m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if network.ID != router.NetworkID { + return status.NewNetworkNotFoundError(router.NetworkID) + } + + router.ID = xid.New().String() + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to create network router: %w", err) + } + + return nil + }) if err != nil { - return nil, fmt.Errorf("failed to create network router: %w", err) + return nil, err } + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network.Name)) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -122,11 +147,30 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } - err = m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) + } + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to update network router: %w", err) + } + + return nil + }) if err != nil { - return nil, fmt.Errorf("failed to update network router: %w", err) + return nil, err } + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network.Name)) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -141,12 +185,77 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo return status.NewPermissionDeniedError() } - err = m.store.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + var event func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, networkID, routerID) + return err + }) if err != nil { - return fmt.Errorf("failed to delete network router: %w", err) + return err } + event() + go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil } + +func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, routerID string) (func(), error) { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) + } + + err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to delete network router: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, "", routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network.Name)) + } + + return event, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + return []*types.NetworkRouter{}, nil +} + +func (m *mockManager) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + return map[string][]*types.NetworkRouter{}, nil +} + +func (m *mockManager) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + return &types.NetworkRouter{}, nil +} + +func (m *mockManager) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + return nil +} + +func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, routerID string) (func(), error) { + return func() {}, nil +} diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index b1491d2d1..4c2e11e90 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -68,3 +68,7 @@ func (n *NetworkRouter) Copy() *NetworkRouter { Metric: n.Metric, } } + +func (n *NetworkRouter) EventMeta(networkName string) map[string]any { + return map[string]any{"network_name": networkName} +} diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go index d95252382..a4ba7b821 100644 --- a/management/server/networks/types/network.go +++ b/management/server/networks/types/network.go @@ -22,7 +22,7 @@ func NewNetwork(accountId, name, description string) *Network { } } -func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int) *api.Network { +func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int, policyIDs []string) *api.Network { return &api.Network{ Id: n.ID, Name: n.Name, @@ -30,6 +30,7 @@ func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routin Routers: routerIDs, Resources: resourceIDs, RoutingPeersCount: routingPeersCount, + Policies: policyIDs, } } @@ -49,3 +50,7 @@ func (n *Network) Copy() *Network { Description: n.Description, } } + +func (n *Network) EventMeta() map[string]any { + return map[string]any{"name": n.Name} +} diff --git a/management/server/route_test.go b/management/server/route_test.go index ef884ec5d..f780f8c99 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -9,13 +9,14 @@ import ( "testing" "time" - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -2501,7 +2502,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { peerE := account.GetPeer("peerE") router1 := getNetworkRouterByID(account, "router1") route1 := getNetworkResourceByID(account, "resource1").ToRoute(peerE, router1) - policies := account.GetPoliciesForNetworkResourceRoute(route1) + policies := account.GetPoliciesForNetworkResource(string(route1.ID)) assert.Len(t, policies, 1, "resource1 should have exactly 1 policy applied directly") // Test case: Resource2 is applied to an access control group (dev), @@ -2509,20 +2510,20 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { peerA := account.GetPeer("peerA") router2 := getNetworkRouterByID(account, "router2") route2 := getNetworkResourceByID(account, "resource2").ToRoute(peerA, router2) - policies = account.GetPoliciesForNetworkResourceRoute(route2) + policies = account.GetPoliciesForNetworkResource(string(route2.ID)) assert.Len(t, policies, 1, "resource2 should have exactly 1 policy applied via access control group") // Test case: Resource3 is not applied to any access control group or policy router3 := getNetworkRouterByID(account, "router3") route3 := getNetworkResourceByID(account, "resource3").ToRoute(peerE, router3) - policies = account.GetPoliciesForNetworkResourceRoute(route3) + policies = account.GetPoliciesForNetworkResource(string(route3.ID)) assert.Len(t, policies, 0, "resource3 should have no policies applied") // Test case: Resource4 is applied to the access control groups (restrictQA and unrestrictedQA), // which is part of the destination in the policies (policyResource3 and policyResource4) router4 := getNetworkRouterByID(account, "router4") route4 := getNetworkResourceByID(account, "resource4").ToRoute(peerA, router4) - policies = account.GetPoliciesForNetworkResourceRoute(route4) + policies = account.GetPoliciesForNetworkResource(string(route4.ID)) assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups") }) diff --git a/management/server/status/error.go b/management/server/status/error.go index d65931b5a..d9cab0231 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -178,3 +178,11 @@ func NewPermissionDeniedError() error { func NewPermissionValidationError(err error) error { return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err) } + +func NewResourceNotPartOfNetworkError(resourceID, networkID string) error { + return Errorf(BadRequest, "resource %s is not part of the network %s", resourceID, networkID) +} + +func NewRouterNotPartOfNetworkError(routerID, networkID string) error { + return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index ed4d2fb28..62b004f9c 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -588,6 +588,25 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return groups, nil } +func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + var groups []*types.Group + + likePattern := `%"ID":"` + resourceID + `"%` + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("resources LIKE ?", likePattern). + Find(&groups) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, result.Error + } + + return groups, nil +} + func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { var accounts []types.Account result := s.db.Find(&accounts) @@ -1019,6 +1038,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string return nil } +// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { var group types.Group result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) @@ -1044,6 +1064,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer return nil } +// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { var group types.Group result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) @@ -1070,6 +1091,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error { var group types.Group result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) @@ -1096,6 +1118,32 @@ func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, gro return nil } +// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction +func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error { + var group types.Group + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.NewGroupNotFoundError(groupID) + } + + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) + } + + for i, res := range group.Resources { + if res.ID == resourceID { + group.Resources = append(group.Resources[:i], group.Resources[i+1:]...) + break + } + } + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group: %s", err) + } + + return nil +} + // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index ec4b49534..845bc8fd4 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2494,7 +2494,7 @@ func TestSqlStore_SaveNetworkResource(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" networkID := "ct286bi7qv930dsrrug0" - netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com") + netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}) require.NoError(t, err) err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) @@ -2529,3 +2529,35 @@ func TestSqlStore_DeleteNetworkResource(t *testing.T) { require.Equal(t, status.NotFound, sErr.Type()) require.Nil(t, netResource) } + +func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + resourceId := "ctc4nci7qv9061u6ilfg" + groupID := "cs1tnh0hhcjnqoiuebeg" + + res := &types.Resource{ + ID: resourceId, + Type: "host", + } + err = store.AddResourceToGroup(context.Background(), accountID, groupID, res) + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.Contains(t, group.Resources, *res) + + groups, err := store.GetResourceGroups(context.Background(), LockingStrengthShare, accountID, resourceId) + require.NoError(t, err) + require.Len(t, groups, 1) + + err = store.RemoveResourceFromGroup(context.Background(), accountID, groupID, res.ID) + require.NoError(t, err) + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.NotContains(t, group.Resources, *res) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 6b0e862c4..d9dc6b8f7 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -74,7 +74,8 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*types.Group, error) + GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error @@ -99,6 +100,7 @@ type Store interface { AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error + RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) diff --git a/management/server/types/account.go b/management/server/types/account.go index 41e244d06..36efa6590 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1251,7 +1251,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peerI routes := a.getRoutingPeerNetworkResourcesRoutes(ctx, peerID) for _, route := range routes { - resourceAppliedPolicies := a.GetPoliciesForNetworkResourceRoute(route) + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(string(route.ID)) distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peerID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) @@ -1261,13 +1261,13 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peerI return routesFirewallRules } -// getNetworkResourceGroups retrieves all groups associated with the given network resource route. -func (a *Account) getNetworkResourceGroups(route *route.Route) []*Group { +// getNetworkResourceGroups retrieves all groups associated with the given network resource. +func (a *Account) getNetworkResourceGroups(resourceID string) []*Group { var networkResourceGroups []*Group for _, group := range a.Groups { for _, resource := range group.Resources { - if resource.ID == string(route.ID) { + if resource.ID == resourceID { networkResourceGroups = append(networkResourceGroups, group) } } @@ -1304,13 +1304,13 @@ func (a *Account) getNetworkResources(networkID string) []*resourceTypes.Network return resources } -// GetPoliciesForNetworkResourceRoute retrieves the list of policies that apply to a specific network resource route. +// GetPoliciesForNetworkResource retrieves the list of policies that apply to a specific network resource. // A policy is deemed applicable if its destination groups include any of the given network resource groups -// or if its destination resource explicitly matches the provided route. -func (a *Account) GetPoliciesForNetworkResourceRoute(route *route.Route) []*Policy { +// or if its destination resource explicitly matches the provided resource. +func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { var resourceAppliedPolicies []*Policy - networkResourceGroups := a.getNetworkResourceGroups(route) + networkResourceGroups := a.getNetworkResourceGroups(resourceId) for _, policy := range a.Policies { if !policy.Enabled { @@ -1329,7 +1329,7 @@ func (a *Account) GetPoliciesForNetworkResourceRoute(route *route.Route) []*Poli } } - if rule.DestinationResource.ID == string(route.ID) { + if rule.DestinationResource.ID == resourceId { resourceAppliedPolicies = append(resourceAppliedPolicies, policy) } } @@ -1338,12 +1338,31 @@ func (a *Account) GetPoliciesForNetworkResourceRoute(route *route.Route) []*Poli return resourceAppliedPolicies } +func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { + networkResources := a.getNetworkResources(networkID) + + policieIDs := map[string]struct{}{} + for _, resource := range networkResources { + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + for _, policy := range resourceAppliedPolicies { + policieIDs[policy.ID] = struct{}{} + } + } + + result := make([]string, 0, len(policieIDs)) + for id := range policieIDs { + result = append(result, id) + } + + return result +} + // getNetworkResourcesRoutes convert the network resources list to routes list. func (a *Account) getNetworkResourcesRoutes(resources []*resourceTypes.NetworkResource, router *routerTypes.NetworkRouter, peer *nbpeer.Peer) []*route.Route { routes := make([]*route.Route, 0, len(resources)) for _, resource := range resources { resourceRoute := resource.ToRoute(peer, router) - resourceAppliedPolicies := a.GetPoliciesForNetworkResourceRoute(resourceRoute) + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(string(resourceRoute.ID)) // distribute the resource routes only if there is policy applied to it if len(resourceAppliedPolicies) > 0 {