Merge branch 'feature/new-networks-concept' into refactor/new-concept-netmap-gen

This commit is contained in:
bcmmbaga 2024-12-18 14:48:20 +03:00
commit be341db10a
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
26 changed files with 775 additions and 177 deletions

View File

@ -169,6 +169,10 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
log.Debugf("registering handler %s with priority %d", handler, priority) log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains { for _, domain := range domains {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.handlerChain.AddHandler(domain, handler, priority, nil) s.handlerChain.AddHandler(domain, handler, priority, nil)
s.handlerPriorities[domain] = priority s.handlerPriorities[domain] = priority
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
@ -188,6 +192,10 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
// Only deregister from service if no handlers remain // Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) { if !s.handlerChain.HasHandlers(domain) {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.service.DeregisterMux(nbdns.NormalizeZone(domain)) s.service.DeregisterMux(nbdns.NormalizeZone(domain))
} }
} }

View File

@ -27,7 +27,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, domains []string) *DNSFor
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
domains: domains, domains: filterDomains(domains),
} }
} }
@ -35,10 +35,6 @@ func (f *DNSForwarder) Listen() error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress) log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux() mux := dns.NewServeMux()
for _, d := range f.domains {
mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery)
}
dnsServer := &dns.Server{ dnsServer := &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
@ -54,10 +50,11 @@ func (f *DNSForwarder) UpdateDomains(domains []string) {
f.mux.HandleRemove(d) f.mux.HandleRemove(d)
} }
for _, d := range f.domains { newDomains := filterDomains(domains)
f.mux.HandleFunc(nbdns.NormalizeZone(d), f.handleDNSQuery) for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery)
} }
f.domains = domains f.domains = newDomains
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@ -141,3 +138,16 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
} }
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))
for _, d := range domains {
if d == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, nbdns.NormalizeZone(d))
}
return newDomains
}

View File

@ -808,12 +808,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
routedDomains, routes := toRoutes(networkMap.GetRoutes()) routedDomains, routes := toRoutes(networkMap.GetRoutes())
e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains)
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains)
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())

View File

@ -67,7 +67,6 @@ func (d *DnsInterceptor) AddRoute(context.Context) error {
func (d *DnsInterceptor) RemoveRoute() error { func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock()
var merr *multierror.Error var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains { for domain, prefixes := range d.interceptedDomains {
@ -89,6 +88,7 @@ func (d *DnsInterceptor) RemoveRoute() error {
} }
clear(d.interceptedDomains) clear(d.interceptedDomains)
d.mu.Unlock()
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)

View File

@ -276,10 +276,10 @@ var (
userManager := users.NewManager(store) userManager := users.NewManager(store)
settingsManager := settings.NewManager(store) settingsManager := settings.NewManager(store)
permissionsManager := permissions.NewManager(userManager, settingsManager) permissionsManager := permissions.NewManager(userManager, settingsManager)
resourcesManager := resources.NewManager(store, permissionsManager, accountManager) groupsManager := groups.NewManager(store, permissionsManager, accountManager)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager) routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
groupsManager := groups.NewManager(store, permissionsManager)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil { if err != nil {

View File

@ -154,6 +154,21 @@ const (
AccountRoutingPeerDNSResolutionEnabled Activity = 71 AccountRoutingPeerDNSResolutionEnabled Activity = 71
AccountRoutingPeerDNSResolutionDisabled Activity = 72 AccountRoutingPeerDNSResolutionDisabled Activity = 72
NetworkCreated Activity = 73
NetworkUpdated Activity = 74
NetworkDeleted Activity = 75
NetworkResourceCreated Activity = 76
NetworkResourceUpdated Activity = 77
NetworkResourceDeleted Activity = 78
NetworkRouterCreated Activity = 79
NetworkRouterUpdated Activity = 80
NetworkRouterDeleted Activity = 81
ResourceAddedToGroup Activity = 82
ResourceRemovedFromGroup Activity = 83
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@ -234,6 +249,21 @@ var activityMap = map[Activity]Code{
AccountRoutingPeerDNSResolutionEnabled: {"Account routing peer DNS resolution enabled", "account.setting.routing.peer.dns.resolution.enable"}, AccountRoutingPeerDNSResolutionEnabled: {"Account routing peer DNS resolution enabled", "account.setting.routing.peer.dns.resolution.enable"},
AccountRoutingPeerDNSResolutionDisabled: {"Account routing peer DNS resolution disabled", "account.setting.routing.peer.dns.resolution.disable"}, AccountRoutingPeerDNSResolutionDisabled: {"Account routing peer DNS resolution disabled", "account.setting.routing.peer.dns.resolution.disable"},
NetworkCreated: {"Network created", "network.create"},
NetworkUpdated: {"Network updated", "network.update"},
NetworkDeleted: {"Network deleted", "network.delete"},
NetworkResourceCreated: {"Network resource created", "network.resource.create"},
NetworkResourceUpdated: {"Network resource updated", "network.resource.update"},
NetworkResourceDeleted: {"Network resource deleted", "network.resource.delete"},
NetworkRouterCreated: {"Network router created", "network.router.create"},
NetworkRouterUpdated: {"Network router updated", "network.router.update"},
NetworkRouterDeleted: {"Network router deleted", "network.router.delete"},
ResourceAddedToGroup: {"Resource added to group", "resource.group.add"},
ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@ -12,18 +14,26 @@ import (
type Manager interface { type Manager interface {
GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error)
GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error)
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID string, resourceID *types.Resource) (func(), error)
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID, resourceID string) (func(), error)
} }
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager permissionsManager permissions.Manager
accountManager s.AccountManager
} }
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
accountManager: accountManager,
} }
} }
@ -58,7 +68,44 @@ func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID,
return err return err
} }
return m.store.AddResourceToGroup(ctx, accountID, groupID, resource) event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, groupID, resource)
if err != nil {
return fmt.Errorf("error adding resource to group: %w", err)
}
event()
return nil
}
func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID string, resource *types.Resource) (func(), error) {
err := transaction.AddResourceToGroup(ctx, accountID, groupID, resource)
if err != nil {
return nil, fmt.Errorf("error adding resource to group: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, accountID, groupID, accountID, activity.ResourceAddedToGroup, nil)
}
return event, nil
}
func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID, resourceID string) (func(), error) {
err := transaction.RemoveResourceFromGroup(ctx, accountID, groupID, resourceID)
if err != nil {
return nil, fmt.Errorf("error removing resource from group: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, accountID, groupID, accountID, activity.ResourceRemovedFromGroup, nil)
}
return event, nil
}
func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
} }
func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum {
@ -97,3 +144,31 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum
} }
return groupsInfo return groupsInfo
} }
func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
return map[string]*types.Group{}, nil
}
func (m *mockManager) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
return []*types.Group{}, nil
}
func (m *mockManager) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error {
return nil
}
func (m *mockManager) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID string, resourceID *types.Resource) (func(), error) {
return func() {
// noop
}, nil
}
func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, groupID, resourceID string) (func(), error) {
return func() {
// noop
}, nil
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@ -1255,11 +1255,18 @@ components:
items: items:
type: string type: string
example: ch8i4ug6lnn4g9hqv7m1 example: ch8i4ug6lnn4g9hqv7m1
policies:
description: List of policy IDs associated with the network
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m2
required: required:
- id - id
- routers - routers
- resources - resources
- routing_peers_count - routing_peers_count
- policies
- $ref: '#/components/schemas/NetworkRequest' - $ref: '#/components/schemas/NetworkRequest'
NetworkResourceMinimum: NetworkResourceMinimum:
type: object type: object

View File

@ -530,6 +530,9 @@ type Network struct {
// Name Network name // Name Network name
Name string `json:"name"` Name string `json:"name"`
// Policies List of policy IDs associated with the network
Policies []string `json:"policies"`
// Resources List of network resource IDs associated with the network // Resources List of network resource IDs associated with the network
Resources []string `json:"resources"` Resources []string `json:"resources"`
@ -551,7 +554,7 @@ type NetworkRequest struct {
// NetworkResource defines model for NetworkResource. // NetworkResource defines model for NetworkResource.
type NetworkResource struct { type NetworkResource struct {
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com)
Address string `json:"address"` Address string `json:"address"`
// Description Network resource description // Description Network resource description
@ -572,7 +575,7 @@ type NetworkResource struct {
// NetworkResourceMinimum defines model for NetworkResourceMinimum. // NetworkResourceMinimum defines model for NetworkResourceMinimum.
type NetworkResourceMinimum struct { type NetworkResourceMinimum struct {
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com)
Address string `json:"address"` Address string `json:"address"`
// Description Network resource description // Description Network resource description
@ -584,7 +587,7 @@ type NetworkResourceMinimum struct {
// NetworkResourceRequest defines model for NetworkResourceRequest. // NetworkResourceRequest defines model for NetworkResourceRequest.
type NetworkResourceRequest struct { type NetworkResourceRequest struct {
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com) // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com)
Address string `json:"address"` Address string `json:"address"`
// Description Network resource description // Description Network resource description

View File

@ -98,7 +98,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa
routes.AddEndpoints(api.AccountManager, authCfg, router) routes.AddEndpoints(api.AccountManager, authCfg, router)
dns.AddEndpoints(api.AccountManager, authCfg, router) dns.AddEndpoints(api.AccountManager, authCfg, router)
events.AddEndpoints(api.AccountManager, authCfg, router) events.AddEndpoints(api.AccountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router)
return rootRouter, nil return rootRouter, nil
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/configs"
@ -27,17 +28,18 @@ type handler struct {
networksManager networks.Manager networksManager networks.Manager
resourceManager resources.Manager resourceManager resources.Manager
routerManager routers.Manager routerManager routers.Manager
accountManager s.AccountManager
groupsManager groups.Manager groupsManager groups.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
addRouterEndpoints(routerManager, extractFromToken, authCfg, router) addRouterEndpoints(routerManager, extractFromToken, authCfg, router)
addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, extractFromToken, authCfg) networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
@ -45,12 +47,13 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
} }
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
return &handler{ return &handler{
networksManager: networksManager, networksManager: networksManager,
resourceManager: resourceManager, resourceManager: resourceManager,
routerManager: routerManager, routerManager: routerManager,
groupsManager: groupsManager, groupsManager: groupsManager,
accountManager: accountManager,
extractFromToken: extractFromToken, extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -91,7 +94,13 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups)) account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
} }
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
@ -119,7 +128,15 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0)) account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyIDs := account.GetPoliciesAppliedInNetwork(network.ID)
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
} }
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
@ -149,7 +166,15 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount)) account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyIDs := account.GetPoliciesAppliedInNetwork(networkID)
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
} }
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
@ -191,7 +216,15 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount)) account, err := h.accountManager.GetAccount(r.Context(), accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policyIDs := account.GetPoliciesAppliedInNetwork(networkID)
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
} }
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
@ -256,11 +289,12 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne
return routerIDs, resourceIDs, peerCounter, nil return routerIDs, resourceIDs, peerCounter, nil
} }
func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group) []*api.Network { func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group, account *nbtypes.Account) []*api.Network {
var networkResponse []*api.Network var networkResponse []*api.Network
for _, network := range networks { for _, network := range networks {
routerIDs, peerCounter := getRouterIDs(network, routers, groups) routerIDs, peerCounter := getRouterIDs(network, routers, groups)
networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter)) policyIDs := account.GetPoliciesAppliedInNetwork(network.ID)
networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter, policyIDs))
} }
return networkResponse return networkResponse
} }

View File

@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbtypes "github.com/netbirdio/netbird/management/server/types"
) )
type resourceHandler struct { type resourceHandler struct {
@ -130,18 +129,6 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
return return
} }
res := nbtypes.Resource{
ID: resource.ID,
Type: resource.Type.String(),
}
for _, groupID := range req.Groups {
err = h.groupsManager.AddResourceToGroup(r.Context(), accountID, userID, groupID, &res)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@ -205,18 +192,6 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
return return
} }
res := nbtypes.Resource{
ID: resource.ID,
Type: resource.Type.String(),
}
for _, groupID := range req.Groups {
err = h.groupsManager.AddResourceToGroup(r.Context(), accountID, userID, groupID, &res)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)

View File

@ -6,7 +6,10 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -23,15 +26,19 @@ type Manager interface {
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
accountManager s.AccountManager
permissionsManager permissions.Manager permissionsManager permissions.Manager
resourcesManager resources.Manager resourcesManager resources.Manager
routersManager routers.Manager
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, manager resources.Manager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
resourcesManager: manager, resourcesManager: resourceManager,
routersManager: routersManager,
accountManager: accountManager,
} }
} }
@ -58,7 +65,14 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
network.ID = xid.New().String() network.ID = xid.New().String()
return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
return network, nil
} }
func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
@ -82,6 +96,13 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
}
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
} }
@ -94,20 +115,24 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
unlock := m.store.AcquireWriteLockByUID(ctx, accountID) network, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID)
defer unlock() if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get resources in network: %w", err) return fmt.Errorf("failed to get resources in network: %w", err)
} }
for _, resource := range resources { for _, resource := range resources {
err = m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resource.ID) event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete resource: %w", err) return fmt.Errorf("failed to delete resource: %w", err)
} }
eventsToStore = append(eventsToStore, event...)
} }
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
@ -116,12 +141,33 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
} }
for _, router := range routers { for _, router := range routers {
err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, router.ID) event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, networkID, router.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete router: %w", err) return fmt.Errorf("failed to delete router: %w", err)
} }
eventsToStore = append(eventsToStore, event)
} }
return transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to delete network: %w", err)
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta())
}) })
return nil
})
if err != nil {
return fmt.Errorf("failed to delete network: %w", err)
}
for _, event := range eventsToStore {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
} }

View File

@ -6,8 +6,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/types" "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@ -25,8 +27,10 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetAllNetworks(ctx, accountID, userID) networks, err := manager.GetAllNetworks(ctx, accountID, userID)
require.NoError(t, err) require.NoError(t, err)
@ -46,8 +50,10 @@ func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetAllNetworks(ctx, accountID, userID) networks, err := manager.GetAllNetworks(ctx, accountID, userID)
require.Error(t, err) require.Error(t, err)
@ -67,8 +73,10 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) networks, err := manager.GetNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err) require.NoError(t, err)
@ -88,8 +96,10 @@ func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
network, err := manager.GetNetwork(ctx, accountID, userID, networkID) network, err := manager.GetNetwork(ctx, accountID, userID, networkID)
require.Error(t, err) require.Error(t, err)
@ -111,8 +121,10 @@ func Test_CreateNetworkSuccessfully(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
createdNetwork, err := manager.CreateNetwork(ctx, userID, network) createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
require.NoError(t, err) require.NoError(t, err)
@ -134,8 +146,10 @@ func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
createdNetwork, err := manager.CreateNetwork(ctx, userID, network) createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
require.Error(t, err) require.Error(t, err)
@ -155,8 +169,10 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
err = manager.DeleteNetwork(ctx, accountID, userID, networkID) err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err) require.NoError(t, err)
@ -175,8 +191,10 @@ func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
err = manager.DeleteNetwork(ctx, accountID, userID, networkID) err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
require.Error(t, err) require.Error(t, err)
@ -198,8 +216,10 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
require.NoError(t, err) require.NoError(t, err)
@ -223,8 +243,10 @@ func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(s, permissionsManager, resourcesManager) routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
require.Error(t, err) require.Error(t, err)

View File

@ -6,10 +6,14 @@ import (
"fmt" "fmt"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
) )
type Manager interface { type Manager interface {
@ -20,19 +24,21 @@ type Manager interface {
GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error)
UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error)
DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error
DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) ([]func(), error)
} }
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager s.AccountManager accountManager s.AccountManager
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager, accountManager: accountManager,
} }
} }
@ -92,21 +98,55 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address) resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create new network resource: %w", err) return nil, fmt.Errorf("failed to create new network resource: %w", err)
} }
_, err = m.store.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
if err == nil { if err == nil {
return nil, errors.New("resource already exists") return errors.New("resource already exists")
} }
err = m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network.Name))
}
eventsToStore = append(eventsToStore, event)
res := nbtypes.Resource{
ID: resource.ID,
Type: resource.Type.String(),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, groupID, &res)
if err != nil {
return fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create network resource: %w", err) return nil, fmt.Errorf("failed to create network resource: %w", err)
} }
for _, event := range eventsToStore {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil return resource, nil
@ -151,17 +191,50 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
resource.Domain = domain resource.Domain = domain
resource.Prefix = prefix resource.Prefix = prefix
_, err = m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
oldResource, err := m.store.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) if network.ID != resource.NetworkID {
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
}
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID { if err == nil && oldResource.ID != resource.ID {
return nil, errors.New("new resource name already exists") return errors.New("new resource name already exists")
} }
err = m.store.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
}
events, err := m.updateResourceGroups(ctx, transaction, resource, oldResource)
if err != nil {
return fmt.Errorf("failed to update resource groups: %w", err)
}
eventsToStore = append(eventsToStore, events...)
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network.Name))
})
return nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update network resource: %w", err) return nil, fmt.Errorf("failed to update network resource: %w", err)
} }
@ -171,6 +244,44 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return resource, nil return resource, nil
} }
func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, newResource, oldResource *types.NetworkResource) ([]func(), error) {
res := nbtypes.Resource{
ID: newResource.ID,
Type: newResource.Type.String(),
}
oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID)
if err != nil {
return nil, fmt.Errorf("failed to get resource groups: %w", err)
}
oldGroupsIds := make([]string, 0)
for _, group := range oldResourceGroups {
oldGroupsIds = append(oldGroupsIds, group.ID)
}
var eventsToStore []func()
groupsToAdd := util.Difference(newResource.GroupIDs, oldGroupsIds)
for _, groupID := range groupsToAdd {
events, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, newResource.AccountID, groupID, &res)
if err != nil {
return nil, fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, events)
}
groupsToRemove := util.Difference(oldGroupsIds, newResource.GroupIDs)
for _, groupID := range groupsToRemove {
events, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, newResource.AccountID, groupID, res.ID)
if err != nil {
return nil, fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, events)
}
return eventsToStore, nil
}
func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write)
if err != nil { if err != nil {
@ -183,43 +294,68 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
unlock := m.store.AcquireWriteLockByUID(ctx, accountID) unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
var events []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resourceID) events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resourceID)
return err
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to delete network resource: %w", err) return fmt.Errorf("failed to delete network resource: %w", err)
} }
for _, event := range events {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID) go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }
func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error { func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) ([]func(), error) {
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID) resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return nil, fmt.Errorf("failed to get network resource: %w", err)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
} }
if resource.NetworkID != networkID { if resource.NetworkID != networkID {
return errors.New("resource not part of network") return nil, errors.New("resource not part of network")
} }
account, err := transaction.GetAccount(ctx, accountID) groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account: %w", err) return nil, fmt.Errorf("failed to get resource groups: %w", err)
} }
account.DeleteResource(resource.ID)
err = transaction.SaveAccount(ctx, account) var eventsToStore []func()
for _, group := range groups {
event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, group.ID, resourceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to save account: %w", err) return nil, fmt.Errorf("failed to remove resource from group: %w", err)
}
eventsToStore = append(eventsToStore, event)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return nil, fmt.Errorf("failed to increment network serial: %w", err)
} }
return transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to delete network resource: %w", err)
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, accountID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network.Name))
})
return eventsToStore, nil
} }

View File

@ -2,11 +2,11 @@ package resources
import ( import (
"context" "context"
"errors"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@ -20,14 +20,15 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
userID := "allowedUser" userID := "allowedUser"
networkID := "testNetworkId" networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err) require.NoError(t, err)
@ -40,14 +41,15 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
userID := "invalidUser" userID := "invalidUser"
networkID := "testNetworkId" networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err) require.Error(t, err)
@ -59,14 +61,15 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
accountID := "testAccountId" accountID := "testAccountId"
userID := "allowedUser" userID := "allowedUser"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err) require.NoError(t, err)
@ -78,14 +81,15 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
accountID := "testAccountId" accountID := "testAccountId"
userID := "invalidUser" userID := "invalidUser"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err) require.Error(t, err)
@ -100,14 +104,15 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
networkID := "testNetworkId" networkID := "testNetworkId"
resourceID := "testResourceId" resourceID := "testResourceId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err) require.NoError(t, err)
@ -121,14 +126,15 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
networkID := "testNetworkId" networkID := "testNetworkId"
resourceID := "testResourceId" resourceID := "testResourceId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err) require.Error(t, err)
@ -154,7 +160,8 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(store, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err) require.NoError(t, err)
@ -179,7 +186,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(store, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@ -205,7 +213,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(store, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@ -230,7 +239,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(store, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@ -252,14 +262,15 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
Address: "1.2.3.0/24", Address: "1.2.3.0/24",
} }
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err) require.NoError(t, err)
@ -283,14 +294,15 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
Address: "1.2.3.0/24", Address: "1.2.3.0/24",
} }
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@ -312,18 +324,18 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
Address: "1.2.3.0/24", Address: "1.2.3.0/24",
} }
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
require.Equal(t, errors.New("new resource name already exists"), err)
require.Nil(t, updatedResource) require.Nil(t, updatedResource)
} }
@ -341,14 +353,15 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
Address: "1.2.3.0/24", Address: "1.2.3.0/24",
} }
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@ -362,14 +375,15 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
networkID := "testNetworkId" networkID := "testNetworkId"
resourceID := "testResourceId" resourceID := "testResourceId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err) require.NoError(t, err)
@ -389,7 +403,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
permissionsManager := permissions.NewManagerMock() permissionsManager := permissions.NewManagerMock()
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
manager := NewManager(store, permissionsManager, &am) groupsManager := groups.NewManagerMock()
manager := NewManager(store, permissionsManager, groupsManager, &am)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err) require.Error(t, err)

View File

@ -6,11 +6,12 @@ import (
"net/netip" "net/netip"
"regexp" "regexp"
"github.com/rs/xid"
nbDomain "github.com/netbirdio/netbird/management/domain" nbDomain "github.com/netbirdio/netbird/management/domain"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
) )
@ -35,11 +36,12 @@ type NetworkResource struct {
Description string Description string
Type NetworkResourceType Type NetworkResourceType
Address string `gorm:"-"` Address string `gorm:"-"`
GroupIDs []string `gorm:"-"`
Domain string Domain string
Prefix netip.Prefix `gorm:"serializer:json"` Prefix netip.Prefix `gorm:"serializer:json"`
} }
func NewNetworkResource(accountID, networkID, name, description, address string) (*NetworkResource, error) { func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string) (*NetworkResource, error) {
resourceType, domain, prefix, err := GetResourceType(address) resourceType, domain, prefix, err := GetResourceType(address)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid address: %w", err) return nil, fmt.Errorf("invalid address: %w", err)
@ -55,6 +57,7 @@ func NewNetworkResource(accountID, networkID, name, description, address string)
Address: address, Address: address,
Domain: domain, Domain: domain,
Prefix: prefix, Prefix: prefix,
GroupIDs: groupIDs,
}, nil }, nil
} }
@ -81,6 +84,7 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
n.Description = *req.Description n.Description = *req.Description
} }
n.Address = req.Address n.Address = req.Address
n.GroupIDs = req.Groups
} }
func (n *NetworkResource) Copy() *NetworkResource { func (n *NetworkResource) Copy() *NetworkResource {
@ -94,6 +98,7 @@ func (n *NetworkResource) Copy() *NetworkResource {
Address: n.Address, Address: n.Address,
Domain: n.Domain, Domain: n.Domain,
Prefix: n.Prefix, Prefix: n.Prefix,
GroupIDs: n.GroupIDs,
} }
} }
@ -137,6 +142,10 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
return r return r
} }
func (n *NetworkResource) EventMeta(networkName string) map[string]any {
return map[string]any{"name": n.Name, "type": n.Type, "network_name": networkName}
}
// GetResourceType returns the type of the resource based on the address // GetResourceType returns the type of the resource based on the address
func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) { func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) {
if prefix, err := netip.ParsePrefix(address); err == nil { if prefix, err := netip.ParsePrefix(address); err == nil {

View File

@ -8,7 +8,9 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@ -21,6 +23,7 @@ type Manager interface {
GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error)
UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error
DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, routerID string) (func(), error)
} }
type managerImpl struct { type managerImpl struct {
@ -29,6 +32,9 @@ type managerImpl struct {
accountManager s.AccountManager accountManager s.AccountManager
} }
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
@ -80,13 +86,32 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
if network.ID != router.NetworkID {
return status.NewNetworkNotFoundError(router.NetworkID)
}
router.ID = xid.New().String() router.ID = xid.New().String()
err = m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create network router: %w", err) return fmt.Errorf("failed to create network router: %w", err)
} }
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network.Name))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil return router, nil
@ -122,11 +147,30 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
err = m.store.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update network router: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
if network.ID != router.NetworkID {
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network.Name))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil return router, nil
@ -141,12 +185,77 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
err = m.store.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) var event func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, networkID, routerID)
return err
})
if err != nil { if err != nil {
return fmt.Errorf("failed to delete network router: %w", err) return err
} }
event()
go m.accountManager.UpdateAccountPeers(ctx, accountID) go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
}
router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID)
if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err)
}
if router.NetworkID != networkID {
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
}
err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID)
if err != nil {
return nil, fmt.Errorf("failed to delete network router: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, "", routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network.Name))
}
return event, nil
}
func NewManagerMock() Manager {
return &mockManager{}
}
func (m *mockManager) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) {
return []*types.NetworkRouter{}, nil
}
func (m *mockManager) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
return map[string][]*types.NetworkRouter{}, nil
}
func (m *mockManager) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
return router, nil
}
func (m *mockManager) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) {
return &types.NetworkRouter{}, nil
}
func (m *mockManager) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
return router, nil
}
func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
return nil
}
func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, routerID string) (func(), error) {
return func() {}, nil
}

View File

@ -68,3 +68,7 @@ func (n *NetworkRouter) Copy() *NetworkRouter {
Metric: n.Metric, Metric: n.Metric,
} }
} }
func (n *NetworkRouter) EventMeta(networkName string) map[string]any {
return map[string]any{"network_name": networkName}
}

View File

@ -22,7 +22,7 @@ func NewNetwork(accountId, name, description string) *Network {
} }
} }
func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int) *api.Network { func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int, policyIDs []string) *api.Network {
return &api.Network{ return &api.Network{
Id: n.ID, Id: n.ID,
Name: n.Name, Name: n.Name,
@ -30,6 +30,7 @@ func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routin
Routers: routerIDs, Routers: routerIDs,
Resources: resourceIDs, Resources: resourceIDs,
RoutingPeersCount: routingPeersCount, RoutingPeersCount: routingPeersCount,
Policies: policyIDs,
} }
} }
@ -49,3 +50,7 @@ func (n *Network) Copy() *Network {
Description: n.Description, Description: n.Description,
} }
} }
func (n *Network) EventMeta() map[string]any {
return map[string]any{"name": n.Name}
}

View File

@ -9,13 +9,14 @@ import (
"testing" "testing"
"time" "time"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -2501,7 +2502,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
peerE := account.GetPeer("peerE") peerE := account.GetPeer("peerE")
router1 := getNetworkRouterByID(account, "router1") router1 := getNetworkRouterByID(account, "router1")
route1 := getNetworkResourceByID(account, "resource1").ToRoute(peerE, router1) route1 := getNetworkResourceByID(account, "resource1").ToRoute(peerE, router1)
policies := account.GetPoliciesForNetworkResourceRoute(route1) policies := account.GetPoliciesForNetworkResource(string(route1.ID))
assert.Len(t, policies, 1, "resource1 should have exactly 1 policy applied directly") assert.Len(t, policies, 1, "resource1 should have exactly 1 policy applied directly")
// Test case: Resource2 is applied to an access control group (dev), // Test case: Resource2 is applied to an access control group (dev),
@ -2509,20 +2510,20 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
peerA := account.GetPeer("peerA") peerA := account.GetPeer("peerA")
router2 := getNetworkRouterByID(account, "router2") router2 := getNetworkRouterByID(account, "router2")
route2 := getNetworkResourceByID(account, "resource2").ToRoute(peerA, router2) route2 := getNetworkResourceByID(account, "resource2").ToRoute(peerA, router2)
policies = account.GetPoliciesForNetworkResourceRoute(route2) policies = account.GetPoliciesForNetworkResource(string(route2.ID))
assert.Len(t, policies, 1, "resource2 should have exactly 1 policy applied via access control group") assert.Len(t, policies, 1, "resource2 should have exactly 1 policy applied via access control group")
// Test case: Resource3 is not applied to any access control group or policy // Test case: Resource3 is not applied to any access control group or policy
router3 := getNetworkRouterByID(account, "router3") router3 := getNetworkRouterByID(account, "router3")
route3 := getNetworkResourceByID(account, "resource3").ToRoute(peerE, router3) route3 := getNetworkResourceByID(account, "resource3").ToRoute(peerE, router3)
policies = account.GetPoliciesForNetworkResourceRoute(route3) policies = account.GetPoliciesForNetworkResource(string(route3.ID))
assert.Len(t, policies, 0, "resource3 should have no policies applied") assert.Len(t, policies, 0, "resource3 should have no policies applied")
// Test case: Resource4 is applied to the access control groups (restrictQA and unrestrictedQA), // Test case: Resource4 is applied to the access control groups (restrictQA and unrestrictedQA),
// which is part of the destination in the policies (policyResource3 and policyResource4) // which is part of the destination in the policies (policyResource3 and policyResource4)
router4 := getNetworkRouterByID(account, "router4") router4 := getNetworkRouterByID(account, "router4")
route4 := getNetworkResourceByID(account, "resource4").ToRoute(peerA, router4) route4 := getNetworkResourceByID(account, "resource4").ToRoute(peerA, router4)
policies = account.GetPoliciesForNetworkResourceRoute(route4) policies = account.GetPoliciesForNetworkResource(string(route4.ID))
assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups") assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups")
}) })

View File

@ -178,3 +178,11 @@ func NewPermissionDeniedError() error {
func NewPermissionValidationError(err error) error { func NewPermissionValidationError(err error) error {
return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err) return Errorf(PermissionDenied, "failed to vlidate user permissions: %s", err)
} }
func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
return Errorf(BadRequest, "resource %s is not part of the network %s", resourceID, networkID)
}
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
}

View File

@ -588,6 +588,25 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return groups, nil return groups, nil
} }
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
var groups []*types.Group
likePattern := `%"ID":"` + resourceID + `"%`
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("resources LIKE ?", likePattern).
Find(&groups)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, result.Error
}
return groups, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
var accounts []types.Account var accounts []types.Account
result := s.db.Find(&accounts) result := s.db.Find(&accounts)
@ -1019,6 +1038,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
return nil return nil
} }
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group types.Group var group types.Group
result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group)
@ -1044,6 +1064,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
return nil return nil
} }
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group types.Group var group types.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
@ -1070,6 +1091,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil return nil
} }
// AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction
func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error { func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error {
var group types.Group var group types.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
@ -1096,6 +1118,32 @@ func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, gro
return nil return nil
} }
// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction
func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error {
var group types.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
}
for i, res := range group.Resources {
if res.ID == resourceID {
group.Resources = append(group.Resources[:i], group.Resources[i+1:]...)
break
}
}
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
}
return nil
}
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID)

View File

@ -2494,7 +2494,7 @@ func TestSqlStore_SaveNetworkResource(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
networkID := "ct286bi7qv930dsrrug0" networkID := "ct286bi7qv930dsrrug0"
netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com") netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{})
require.NoError(t, err) require.NoError(t, err)
err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource)
@ -2529,3 +2529,35 @@ func TestSqlStore_DeleteNetworkResource(t *testing.T) {
require.Equal(t, status.NotFound, sErr.Type()) require.Equal(t, status.NotFound, sErr.Type())
require.Nil(t, netResource) require.Nil(t, netResource)
} }
func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
resourceId := "ctc4nci7qv9061u6ilfg"
groupID := "cs1tnh0hhcjnqoiuebeg"
res := &types.Resource{
ID: resourceId,
Type: "host",
}
err = store.AddResourceToGroup(context.Background(), accountID, groupID, res)
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
require.NoError(t, err)
require.Contains(t, group.Resources, *res)
groups, err := store.GetResourceGroups(context.Background(), LockingStrengthShare, accountID, resourceId)
require.NoError(t, err)
require.Len(t, groups, 1)
err = store.RemoveResourceFromGroup(context.Background(), accountID, groupID, res.ID)
require.NoError(t, err)
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
require.NoError(t, err)
require.NotContains(t, group.Resources, *res)
}

View File

@ -74,7 +74,8 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error
@ -99,6 +100,7 @@ type Store interface {
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)

View File

@ -1251,7 +1251,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peerI
routes := a.getRoutingPeerNetworkResourcesRoutes(ctx, peerID) routes := a.getRoutingPeerNetworkResourcesRoutes(ctx, peerID)
for _, route := range routes { for _, route := range routes {
resourceAppliedPolicies := a.GetPoliciesForNetworkResourceRoute(route) resourceAppliedPolicies := a.GetPoliciesForNetworkResource(string(route.ID))
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
rules := a.getRouteFirewallRules(ctx, peerID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) rules := a.getRouteFirewallRules(ctx, peerID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
@ -1261,13 +1261,13 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peerI
return routesFirewallRules return routesFirewallRules
} }
// getNetworkResourceGroups retrieves all groups associated with the given network resource route. // getNetworkResourceGroups retrieves all groups associated with the given network resource.
func (a *Account) getNetworkResourceGroups(route *route.Route) []*Group { func (a *Account) getNetworkResourceGroups(resourceID string) []*Group {
var networkResourceGroups []*Group var networkResourceGroups []*Group
for _, group := range a.Groups { for _, group := range a.Groups {
for _, resource := range group.Resources { for _, resource := range group.Resources {
if resource.ID == string(route.ID) { if resource.ID == resourceID {
networkResourceGroups = append(networkResourceGroups, group) networkResourceGroups = append(networkResourceGroups, group)
} }
} }
@ -1304,13 +1304,13 @@ func (a *Account) getNetworkResources(networkID string) []*resourceTypes.Network
return resources return resources
} }
// GetPoliciesForNetworkResourceRoute retrieves the list of policies that apply to a specific network resource route. // GetPoliciesForNetworkResource retrieves the list of policies that apply to a specific network resource.
// A policy is deemed applicable if its destination groups include any of the given network resource groups // A policy is deemed applicable if its destination groups include any of the given network resource groups
// or if its destination resource explicitly matches the provided route. // or if its destination resource explicitly matches the provided resource.
func (a *Account) GetPoliciesForNetworkResourceRoute(route *route.Route) []*Policy { func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy {
var resourceAppliedPolicies []*Policy var resourceAppliedPolicies []*Policy
networkResourceGroups := a.getNetworkResourceGroups(route) networkResourceGroups := a.getNetworkResourceGroups(resourceId)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
@ -1329,7 +1329,7 @@ func (a *Account) GetPoliciesForNetworkResourceRoute(route *route.Route) []*Poli
} }
} }
if rule.DestinationResource.ID == string(route.ID) { if rule.DestinationResource.ID == resourceId {
resourceAppliedPolicies = append(resourceAppliedPolicies, policy) resourceAppliedPolicies = append(resourceAppliedPolicies, policy)
} }
} }
@ -1338,12 +1338,31 @@ func (a *Account) GetPoliciesForNetworkResourceRoute(route *route.Route) []*Poli
return resourceAppliedPolicies return resourceAppliedPolicies
} }
func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
networkResources := a.getNetworkResources(networkID)
policieIDs := map[string]struct{}{}
for _, resource := range networkResources {
resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID)
for _, policy := range resourceAppliedPolicies {
policieIDs[policy.ID] = struct{}{}
}
}
result := make([]string, 0, len(policieIDs))
for id := range policieIDs {
result = append(result, id)
}
return result
}
// getNetworkResourcesRoutes convert the network resources list to routes list. // getNetworkResourcesRoutes convert the network resources list to routes list.
func (a *Account) getNetworkResourcesRoutes(resources []*resourceTypes.NetworkResource, router *routerTypes.NetworkRouter, peer *nbpeer.Peer) []*route.Route { func (a *Account) getNetworkResourcesRoutes(resources []*resourceTypes.NetworkResource, router *routerTypes.NetworkRouter, peer *nbpeer.Peer) []*route.Route {
routes := make([]*route.Route, 0, len(resources)) routes := make([]*route.Route, 0, len(resources))
for _, resource := range resources { for _, resource := range resources {
resourceRoute := resource.ToRoute(peer, router) resourceRoute := resource.ToRoute(peer, router)
resourceAppliedPolicies := a.GetPoliciesForNetworkResourceRoute(resourceRoute) resourceAppliedPolicies := a.GetPoliciesForNetworkResource(string(resourceRoute.ID))
// distribute the resource routes only if there is policy applied to it // distribute the resource routes only if there is policy applied to it
if len(resourceAppliedPolicies) > 0 { if len(resourceAppliedPolicies) > 0 {