mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-23 22:38:36 +01:00
3bfa26b13b
Converts rules to Rego policies and allows users to write raw policies to set up connectivity and firewall on the clients.
327 lines
8.3 KiB
Go
327 lines
8.3 KiB
Go
package http
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/rs/xid"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/management/server"
|
|
"github.com/netbirdio/netbird/management/server/http/api"
|
|
"github.com/netbirdio/netbird/management/server/http/util"
|
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
"github.com/netbirdio/netbird/management/server/status"
|
|
)
|
|
|
|
// Policies is a handler that returns policy of the account
|
|
type Policies struct {
|
|
accountManager server.AccountManager
|
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
|
}
|
|
|
|
// NewPoliciesHandler creates a new Policies handler
|
|
func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies {
|
|
return &Policies{
|
|
accountManager: accountManager,
|
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
|
jwtclaims.WithAudience(authCfg.Audience),
|
|
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
|
),
|
|
}
|
|
}
|
|
|
|
// GetAllPolicies list for the account
|
|
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
|
claims := h.claimsExtractor.FromRequestContext(r)
|
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
util.WriteJSONObject(w, accountPolicies)
|
|
}
|
|
|
|
// UpdatePolicy handles update to a policy identified by a given ID
|
|
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|
claims := h.claimsExtractor.FromRequestContext(r)
|
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
vars := mux.Vars(r)
|
|
policyID := vars["id"]
|
|
if len(policyID) == 0 {
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
|
return
|
|
}
|
|
|
|
policyIdx := -1
|
|
for i, policy := range account.Policies {
|
|
if policy.ID == policyID {
|
|
policyIdx = i
|
|
break
|
|
}
|
|
}
|
|
if policyIdx < 0 {
|
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
|
return
|
|
}
|
|
|
|
var req api.PutApiPoliciesIdJSONRequestBody
|
|
err = json.NewDecoder(r.Body).Decode(&req)
|
|
if err != nil {
|
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
return
|
|
}
|
|
|
|
if req.Name == "" {
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
|
|
return
|
|
}
|
|
|
|
policy := server.Policy{
|
|
ID: policyID,
|
|
Name: req.Name,
|
|
Enabled: req.Enabled,
|
|
Description: req.Description,
|
|
Query: req.Query,
|
|
}
|
|
if req.Rules != nil {
|
|
for _, r := range req.Rules {
|
|
pr := server.PolicyRule{
|
|
Destinations: groupMinimumsToStrings(account, r.Destinations),
|
|
Sources: groupMinimumsToStrings(account, r.Sources),
|
|
Name: r.Name,
|
|
}
|
|
pr.Enabled = r.Enabled
|
|
if r.Description != nil {
|
|
pr.Description = *r.Description
|
|
}
|
|
if r.Id != nil {
|
|
pr.ID = *r.Id
|
|
}
|
|
switch r.Action {
|
|
case api.PolicyRuleActionAccept:
|
|
pr.Action = server.PolicyTrafficActionAccept
|
|
case api.PolicyRuleActionDrop:
|
|
pr.Action = server.PolicyTrafficActionDrop
|
|
default:
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
|
return
|
|
}
|
|
policy.Rules = append(policy.Rules, &pr)
|
|
}
|
|
}
|
|
if err := policy.UpdateQueryFromRules(); err != nil {
|
|
log.Errorf("failed to update policy query: %v", err)
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
if err = h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
util.WriteJSONObject(w, toPolicyResponse(account, &policy))
|
|
}
|
|
|
|
// CreatePolicy handles policy creation request
|
|
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
|
claims := h.claimsExtractor.FromRequestContext(r)
|
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
var req api.PostApiPoliciesJSONRequestBody
|
|
err = json.NewDecoder(r.Body).Decode(&req)
|
|
if err != nil {
|
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
|
return
|
|
}
|
|
|
|
if req.Name == "" {
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
|
|
return
|
|
}
|
|
|
|
policy := &server.Policy{
|
|
ID: xid.New().String(),
|
|
Name: req.Name,
|
|
Enabled: req.Enabled,
|
|
Description: req.Description,
|
|
Query: req.Query,
|
|
}
|
|
|
|
if req.Rules != nil {
|
|
for _, r := range req.Rules {
|
|
pr := server.PolicyRule{
|
|
ID: xid.New().String(),
|
|
Destinations: groupMinimumsToStrings(account, r.Destinations),
|
|
Sources: groupMinimumsToStrings(account, r.Sources),
|
|
Name: r.Name,
|
|
}
|
|
pr.Enabled = r.Enabled
|
|
if r.Description != nil {
|
|
pr.Description = *r.Description
|
|
}
|
|
switch r.Action {
|
|
case api.PolicyRuleActionAccept:
|
|
pr.Action = server.PolicyTrafficActionAccept
|
|
case api.PolicyRuleActionDrop:
|
|
pr.Action = server.PolicyTrafficActionDrop
|
|
default:
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
|
return
|
|
}
|
|
policy.Rules = append(policy.Rules, &pr)
|
|
}
|
|
}
|
|
if err := policy.UpdateQueryFromRules(); err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
if err = h.accountManager.SavePolicy(account.Id, user.Id, policy); err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
util.WriteJSONObject(w, toPolicyResponse(account, policy))
|
|
}
|
|
|
|
// DeletePolicy handles policy deletion request
|
|
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|
claims := h.claimsExtractor.FromRequestContext(r)
|
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
aID := account.Id
|
|
|
|
vars := mux.Vars(r)
|
|
policyID := vars["id"]
|
|
if len(policyID) == 0 {
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
|
return
|
|
}
|
|
|
|
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
util.WriteJSONObject(w, "")
|
|
}
|
|
|
|
// GetPolicy handles a group Get request identified by ID
|
|
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
|
claims := h.claimsExtractor.FromRequestContext(r)
|
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
vars := mux.Vars(r)
|
|
policyID := vars["id"]
|
|
if len(policyID) == 0 {
|
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
|
return
|
|
}
|
|
|
|
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id)
|
|
if err != nil {
|
|
util.WriteError(err, w)
|
|
return
|
|
}
|
|
|
|
util.WriteJSONObject(w, toPolicyResponse(account, policy))
|
|
default:
|
|
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
|
|
}
|
|
}
|
|
|
|
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
|
|
cache := make(map[string]api.GroupMinimum)
|
|
ap := &api.Policy{
|
|
Id: policy.ID,
|
|
Name: policy.Name,
|
|
Description: policy.Description,
|
|
Enabled: policy.Enabled,
|
|
Query: policy.Query,
|
|
}
|
|
if len(policy.Rules) == 0 {
|
|
return ap
|
|
}
|
|
|
|
for _, r := range policy.Rules {
|
|
rule := api.PolicyRule{
|
|
Id: &r.ID,
|
|
Name: r.Name,
|
|
Enabled: r.Enabled,
|
|
Description: &r.Description,
|
|
}
|
|
for _, gid := range r.Sources {
|
|
_, ok := cache[gid]
|
|
if ok {
|
|
continue
|
|
}
|
|
if group, ok := account.Groups[gid]; ok {
|
|
minimum := api.GroupMinimum{
|
|
Id: group.ID,
|
|
Name: group.Name,
|
|
PeersCount: len(group.Peers),
|
|
}
|
|
rule.Sources = append(rule.Sources, minimum)
|
|
cache[gid] = minimum
|
|
}
|
|
}
|
|
for _, gid := range r.Destinations {
|
|
cachedMinimum, ok := cache[gid]
|
|
if ok {
|
|
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
|
continue
|
|
}
|
|
if group, ok := account.Groups[gid]; ok {
|
|
minimum := api.GroupMinimum{
|
|
Id: group.ID,
|
|
Name: group.Name,
|
|
PeersCount: len(group.Peers),
|
|
}
|
|
rule.Destinations = append(rule.Destinations, minimum)
|
|
cache[gid] = minimum
|
|
}
|
|
}
|
|
ap.Rules = append(ap.Rules, rule)
|
|
}
|
|
return ap
|
|
}
|
|
|
|
func groupMinimumsToStrings(account *server.Account, gm []api.GroupMinimum) []string {
|
|
result := make([]string, 0, len(gm))
|
|
for _, gm := range gm {
|
|
if _, ok := account.Groups[gm.Id]; ok {
|
|
continue
|
|
}
|
|
result = append(result, gm.Id)
|
|
}
|
|
return result
|
|
}
|