netbird/management/server/http/policies_handler.go

373 lines
10 KiB
Go
Raw Normal View History

package http
import (
"encoding/json"
"net/http"
"strconv"
"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// Policies is a handler that returns policy of the account
type Policies struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewPoliciesHandler creates a new Policies handler
func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies {
return &Policies{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
policies := []*api.Policy{}
for _, policy := range accountPolicies {
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
policies = append(policies, resp)
}
util.WriteJSONObject(r.Context(), w, policies)
}
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
h.savePolicy(w, r, account, user, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, account, user, "")
}
// savePolicy handles policy creation and update
func (h *Policies) savePolicy(
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
policyID string,
) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
return
}
if len(req.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
return
}
if policyID == "" {
policyID = xid.New().String()
}
policy := server.Policy{
ID: policyID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
Destinations: groupMinimumsToStrings(account, rule.Destinations),
Sources: groupMinimumsToStrings(account, rule.Sources),
Bidirectional: rule.Bidirectional,
}
pr.Enabled = rule.Enabled
if rule.Description != nil {
pr.Description = *rule.Description
}
switch rule.Action {
case api.PolicyRuleUpdateActionAccept:
pr.Action = server.PolicyTrafficActionAccept
case api.PolicyRuleUpdateActionDrop:
pr.Action = server.PolicyTrafficActionDrop
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
return
}
switch rule.Protocol {
case api.PolicyRuleUpdateProtocolAll:
pr.Protocol = server.PolicyRuleProtocolALL
case api.PolicyRuleUpdateProtocolTcp:
pr.Protocol = server.PolicyRuleProtocolTCP
case api.PolicyRuleUpdateProtocolUdp:
pr.Protocol = server.PolicyRuleProtocolUDP
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = server.PolicyRuleProtocolICMP
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return
}
if rule.Ports != nil && len(*rule.Ports) != 0 {
for _, v := range *rule.Ports {
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
return
}
pr.Ports = append(pr.Ports, v)
}
}
// validate policy object
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
if len(pr.Ports) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
}
if !pr.Bidirectional {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
if !pr.Bidirectional && len(pr.Ports) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
}
policy.Rules = append(policy.Rules, &pr)
}
if req.SourcePostureChecks != nil {
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
}
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, &policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
}
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
aID := account.Id
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
switch r.Method {
case http.MethodGet:
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
}
}
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{
Id: &policy.ID,
Name: policy.Name,
Description: policy.Description,
Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks,
}
for _, r := range policy.Rules {
rID := r.ID
rDescription := r.Description
rule := api.PolicyRule{
Id: &rID,
Name: r.Name,
Enabled: r.Enabled,
Description: &rDescription,
Bidirectional: r.Bidirectional,
Protocol: api.PolicyRuleProtocol(r.Protocol),
Action: api.PolicyRuleAction(r.Action),
}
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy
}
for _, gid := range r.Sources {
_, ok := cache[gid]
if ok {
continue
}
if group, ok := account.Groups[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
rule.Sources = append(rule.Sources, minimum)
cache[gid] = minimum
}
}
for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid]
if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum)
continue
}
if group, ok := account.Groups[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
rule.Destinations = append(rule.Destinations, minimum)
cache[gid] = minimum
}
}
ap.Rules = append(ap.Rules, rule)
}
return ap
}
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
result := make([]string, 0, len(gm))
for _, g := range gm {
if _, ok := account.Groups[g]; !ok {
continue
}
result = append(result, g)
}
return result
}
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}