netbird/management/server/policy.go

603 lines
17 KiB
Go

package server
import (
"context"
_ "embed"
"slices"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
// PolicyUpdateOperationType operation type
type PolicyUpdateOperationType int
// PolicyTrafficActionType action type for the firewall
type PolicyTrafficActionType string
// PolicyRuleProtocolType type of traffic
type PolicyRuleProtocolType string
// PolicyRuleDirection direction of traffic
type PolicyRuleDirection string
const (
// PolicyTrafficActionAccept indicates that the traffic is accepted
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
// PolicyTrafficActionDrop indicates that the traffic is dropped
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
)
const (
// PolicyRuleProtocolALL type of traffic
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
// PolicyRuleProtocolTCP type of traffic
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
// PolicyRuleProtocolUDP type of traffic
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
)
const (
// PolicyRuleFlowDirect allows traffic from source to destination
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
// PolicyRuleFlowBidirect allows traffic to both directions
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
)
const (
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
const (
firewallRuleDirectionIN = 0
firewallRuleDirectionOUT = 1
)
// PolicyUpdateOperation operation object with type and values to be applied
type PolicyUpdateOperation struct {
Type PolicyUpdateOperationType
Values []string
}
// RulePortRange represents a range of ports for a firewall rule.
type RulePortRange struct {
Start uint16
End uint16
}
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule
ID string `gorm:"primaryKey"`
// PolicyID is a reference to Policy that this object belongs
PolicyID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Enabled status of rule in the system
Enabled bool
// Action policy accept or drops packets
Action PolicyTrafficActionType
// Destinations policy destination groups
Destinations []string `gorm:"serializer:json"`
// Sources policy source groups
Sources []string `gorm:"serializer:json"`
// Bidirectional define if the rule is applicable in both directions, sources, and destinations
Bidirectional bool
// Protocol type of the traffic
Protocol PolicyRuleProtocolType
// Ports or it ranges list
Ports []string `gorm:"serializer:json"`
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a copy of a policy rule
func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{
ID: pm.ID,
Name: pm.Name,
Description: pm.Description,
Enabled: pm.Enabled,
Action: pm.Action,
Destinations: make([]string, len(pm.Destinations)),
Sources: make([]string, len(pm.Sources)),
Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
return rule
}
// Policy of the Rego query
type Policy struct {
// ID of the policy'
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name of the Policy
Name string
// Description of the policy visible in the UI
Description string
// Enabled status of the policy
Enabled bool
// Rules of the policy
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`
}
// Copy returns a copy of the policy.
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
Rules: make([]*PolicyRule, len(p.Rules)),
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
}
for i, r := range p.Rules {
c.Rules[i] = r.Copy()
}
copy(c.SourcePostureChecks, p.SourcePostureChecks)
return c
}
// EventMeta returns activity event meta related to this policy
func (p *Policy) EventMeta() map[string]any {
return map[string]any{"name": p.Name}
}
// UpgradeAndFix different version of policies to latest version
func (p *Policy) UpgradeAndFix() {
for _, r := range p.Rules {
// start migrate from version v0.20.3
if r.Protocol == "" {
r.Protocol = PolicyRuleProtocolALL
}
if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional {
r.Bidirectional = true
}
// -- v0.20.4
}
}
// ruleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) ruleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
PeerIP string
// Direction of the traffic
Direction int
// Action of the traffic
Action string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port string
}
// getPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, firewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, firewallRuleDirectionOUT)
}
}
if peerInSources {
generateResources(rule, destinationPeers, firewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, firewallRuleDirectionIN)
}
}
}
return getAccumulatedResources()
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
//
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
all, err := a.GetGroupAll()
if err != nil {
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
all = &nbgroup.Group{}
}
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
}
if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
peersExists[peer.ID] = struct{}{}
}
fr := FirewallRule{
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
}
if isAll {
fr.PeerIP = "0.0.0.0"
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 {
rules = append(rules, &fr)
continue
}
for _, port := range rule.Ports {
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
}
}
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
if err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
action := activity.PolicyAdded
if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
} else {
log.WithContext(ctx).Tracef("Skipping account peers update for policy: %s", policy.ID)
}
return nil
}
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
policy, err := am.deletePolicy(account, policyID)
if err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
am.updateAccountPeers(ctx, account)
return nil
}
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
oldPolicy := account.Policies[policyIdx]
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil
}
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return updateAccountPeers, nil
}
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
}
return result
}
// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns a list of peers from specified groups that pass specified posture checks
// and a boolean indicating if the supplied peer ID exists within these groups.
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
group, ok := a.Groups[g]
if !ok {
continue
}
for _, p := range group.Peers {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
}
}
return filteredPeers, peerInGroups
}
// validatePostureChecksOnPeer validates the posture checks on a peer
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
peer, ok := a.Peers[peerID]
if !ok && peer == nil {
return false
}
for _, postureChecksID := range sourcePostureChecksID {
postureChecks := a.getPostureChecks(postureChecksID)
if postureChecks == nil {
continue
}
for _, check := range postureChecks.GetChecks() {
isValid, err := check.Check(ctx, *peer)
if err != nil {
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
}
if !isValid {
return false
}
}
}
return true
}
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
for _, postureChecks := range a.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks
}
}
return nil
}
// filterValidPostureChecks filters and returns the posture check IDs from the given list
// that are valid within the provided account.
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
}
}
return result
}