mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-25 15:28:38 +01:00
322 lines
10 KiB
Go
322 lines
10 KiB
Go
|
package networks
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
|
||
|
"github.com/gorilla/mux"
|
||
|
|
||
|
s "github.com/netbirdio/netbird/management/server"
|
||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||
|
"github.com/netbirdio/netbird/management/server/networks"
|
||
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||
|
"github.com/netbirdio/netbird/management/server/networks/types"
|
||
|
"github.com/netbirdio/netbird/management/server/status"
|
||
|
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||
|
)
|
||
|
|
||
|
// handler is a handler that returns networks of the account
|
||
|
type handler struct {
|
||
|
networksManager networks.Manager
|
||
|
resourceManager resources.Manager
|
||
|
routerManager routers.Manager
|
||
|
accountManager s.AccountManager
|
||
|
|
||
|
groupsManager groups.Manager
|
||
|
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||
|
}
|
||
|
|
||
|
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||
|
addRouterEndpoints(routerManager, extractFromToken, authCfg, router)
|
||
|
addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router)
|
||
|
|
||
|
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg)
|
||
|
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
|
||
|
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
|
||
|
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
|
||
|
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
|
||
|
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||
|
}
|
||
|
|
||
|
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
|
||
|
return &handler{
|
||
|
networksManager: networksManager,
|
||
|
resourceManager: resourceManager,
|
||
|
routerManager: routerManager,
|
||
|
groupsManager: groupsManager,
|
||
|
accountManager: accountManager,
|
||
|
extractFromToken: extractFromToken,
|
||
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||
|
jwtclaims.WithAudience(authCfg.Audience),
|
||
|
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||
|
),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||
|
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
|
||
|
}
|
||
|
|
||
|
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||
|
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var req api.NetworkRequest
|
||
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||
|
if err != nil {
|
||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
network := &types.Network{}
|
||
|
network.FromAPIRequest(&req)
|
||
|
|
||
|
network.AccountID = accountID
|
||
|
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
policyIDs := account.GetPoliciesAppliedInNetwork(network.ID)
|
||
|
|
||
|
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
|
||
|
}
|
||
|
|
||
|
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||
|
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
vars := mux.Vars(r)
|
||
|
networkID := vars["networkId"]
|
||
|
if len(networkID) == 0 {
|
||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
policyIDs := account.GetPoliciesAppliedInNetwork(networkID)
|
||
|
|
||
|
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||
|
}
|
||
|
|
||
|
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||
|
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
vars := mux.Vars(r)
|
||
|
networkID := vars["networkId"]
|
||
|
if len(networkID) == 0 {
|
||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var req api.NetworkRequest
|
||
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||
|
if err != nil {
|
||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
network := &types.Network{}
|
||
|
network.FromAPIRequest(&req)
|
||
|
|
||
|
network.ID = networkID
|
||
|
network.AccountID = accountID
|
||
|
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
policyIDs := account.GetPoliciesAppliedInNetwork(networkID)
|
||
|
|
||
|
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||
|
}
|
||
|
|
||
|
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||
|
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
vars := mux.Vars(r)
|
||
|
networkID := vars["networkId"]
|
||
|
if len(networkID) == 0 {
|
||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
|
||
|
if err != nil {
|
||
|
util.WriteError(r.Context(), err, w)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||
|
}
|
||
|
|
||
|
func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, networkID string) ([]string, []string, int, error) {
|
||
|
resources, err := h.resourceManager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||
|
if err != nil {
|
||
|
return nil, nil, 0, fmt.Errorf("failed to get resources in network: %w", err)
|
||
|
}
|
||
|
|
||
|
var resourceIDs []string
|
||
|
for _, resource := range resources {
|
||
|
resourceIDs = append(resourceIDs, resource.ID)
|
||
|
}
|
||
|
|
||
|
routers, err := h.routerManager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID)
|
||
|
if err != nil {
|
||
|
return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err)
|
||
|
}
|
||
|
|
||
|
groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID)
|
||
|
if err != nil {
|
||
|
return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err)
|
||
|
}
|
||
|
|
||
|
peerCounter := 0
|
||
|
var routerIDs []string
|
||
|
for _, router := range routers {
|
||
|
routerIDs = append(routerIDs, router.ID)
|
||
|
if router.Peer != "" {
|
||
|
peerCounter++
|
||
|
}
|
||
|
if len(router.PeerGroups) > 0 {
|
||
|
for _, groupID := range router.PeerGroups {
|
||
|
peerCounter += len(groups[groupID].Peers)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return routerIDs, resourceIDs, peerCounter, nil
|
||
|
}
|
||
|
|
||
|
func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group, account *nbtypes.Account) []*api.Network {
|
||
|
var networkResponse []*api.Network
|
||
|
for _, network := range networks {
|
||
|
routerIDs, peerCounter := getRouterIDs(network, routers, groups)
|
||
|
policyIDs := account.GetPoliciesAppliedInNetwork(network.ID)
|
||
|
networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter, policyIDs))
|
||
|
}
|
||
|
return networkResponse
|
||
|
}
|
||
|
|
||
|
func getRouterIDs(network *types.Network, routers map[string][]*routerTypes.NetworkRouter, groups map[string]*nbtypes.Group) ([]string, int) {
|
||
|
routerIDs := []string{}
|
||
|
peerCounter := 0
|
||
|
for _, router := range routers[network.ID] {
|
||
|
routerIDs = append(routerIDs, router.ID)
|
||
|
if router.Peer != "" {
|
||
|
peerCounter++
|
||
|
}
|
||
|
if len(router.PeerGroups) > 0 {
|
||
|
for _, groupID := range router.PeerGroups {
|
||
|
group, ok := groups[groupID]
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
peerCounter += len(group.Peers)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return routerIDs, peerCounter
|
||
|
}
|