netbird/management/server/policy.go
bcmmbaga 32d1b2d602
Retrieve policy groups and posture checks once for validation
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-12 18:53:10 +03:00

671 lines
19 KiB
Go

package server
import (
"context"
_ "embed"
"strconv"
"strings"
"github.com/netbirdio/netbird/management/proto"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
// PolicyUpdateOperationType identifies the kind of update carried by a
// PolicyUpdateOperation.
type PolicyUpdateOperationType int
// PolicyTrafficActionType is the firewall action a policy rule applies to
// traffic; see the PolicyTrafficAction* constants for the values defined here.
type PolicyTrafficActionType string
// PolicyRuleProtocolType is the traffic protocol a policy rule applies to;
// see the PolicyRuleProtocol* constants for the values defined here.
type PolicyRuleProtocolType string
// PolicyRuleDirection is the direction of traffic flow; see the
// PolicyRuleFlow* constants for the values defined here.
type PolicyRuleDirection string
// Values of PolicyTrafficActionType.
const (
// PolicyTrafficActionAccept indicates that the traffic is accepted
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
// PolicyTrafficActionDrop indicates that the traffic is dropped
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
)
// Values of PolicyRuleProtocolType.
const (
// PolicyRuleProtocolALL is the protocol value "all". Policy.UpgradeAndFix
// assigns it to rules that have no protocol set.
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
// PolicyRuleProtocolTCP is the protocol value "tcp"
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
// PolicyRuleProtocolUDP is the protocol value "udp"
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP is the protocol value "icmp"
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
)
// Values of PolicyRuleDirection.
const (
// PolicyRuleFlowDirect allows traffic from source to destination
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
// PolicyRuleFlowBidirect allows traffic in both directions
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
)
// Names and descriptions of the default rule and default policy.
const (
// DefaultRuleName is the name of the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is the description of the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is the name of the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is the description of the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
// Direction values passed to the connection resources generator and stored
// in FirewallRule.Direction.
const (
// firewallRuleDirectionIN marks a firewall rule for the IN direction
firewallRuleDirectionIN = 0
// firewallRuleDirectionOUT marks a firewall rule for the OUT direction
firewallRuleDirectionOUT = 1
)
// PolicyUpdateOperation describes a single update to apply: the operation
// type together with the values it should be applied with.
type PolicyUpdateOperation struct {
// Type is the kind of update operation
Type PolicyUpdateOperationType
// Values are the values the operation is applied with
Values []string
}
// RulePortRange represents a range of ports for a firewall rule.
// NOTE(review): whether End is inclusive is not visible in this file — confirm
// against the code that consumes PolicyRule.PortRanges.
type RulePortRange struct {
// Start is the first port of the range
Start uint16
// End is the last port of the range
End uint16
}
// PolicyRule is a single traffic rule that belongs to a Policy. It names the
// source and destination groups and the action, protocol and ports that the
// rule applies to.
type PolicyRule struct {
// ID of the policy rule
ID string `gorm:"primaryKey"`
// PolicyID is a reference to the Policy that this rule belongs to
PolicyID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Enabled status of rule in the system
Enabled bool
// Action determines whether matching traffic is accepted or dropped
Action PolicyTrafficActionType
// Destinations is the list of destination group IDs of the rule
Destinations []string `gorm:"serializer:json"`
// Sources is the list of source group IDs of the rule
Sources []string `gorm:"serializer:json"`
// Bidirectional defines whether the rule applies in both directions,
// i.e. from sources to destinations and from destinations to sources
Bidirectional bool
// Protocol type of the traffic
Protocol PolicyRuleProtocolType
// Ports is the list of individual ports the rule applies to
Ports []string `gorm:"serializer:json"`
// PortRanges is a list of port ranges the rule applies to
PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a deep copy of the rule. Scalar fields are copied by value and
// every slice field receives its own backing array, so changes to the copy
// never leak into the receiver. The slices of the copy are always non-nil,
// even when the receiver's slices are nil.
func (pm *PolicyRule) Copy() *PolicyRule {
	destinations := make([]string, len(pm.Destinations))
	copy(destinations, pm.Destinations)

	sources := make([]string, len(pm.Sources))
	copy(sources, pm.Sources)

	ports := make([]string, len(pm.Ports))
	copy(ports, pm.Ports)

	portRanges := make([]RulePortRange, len(pm.PortRanges))
	copy(portRanges, pm.PortRanges)

	return &PolicyRule{
		ID:            pm.ID,
		PolicyID:      pm.PolicyID,
		Name:          pm.Name,
		Description:   pm.Description,
		Enabled:       pm.Enabled,
		Action:        pm.Action,
		Destinations:  destinations,
		Sources:       sources,
		Bidirectional: pm.Bidirectional,
		Protocol:      pm.Protocol,
		Ports:         ports,
		PortRanges:    portRanges,
	}
}
// Policy is a named collection of rules that belongs to an account, together
// with the posture checks referenced for the rules' source groups.
type Policy struct {
// ID of the policy
ID string `gorm:"primaryKey"`
// AccountID is a reference to the Account that this policy belongs to
AccountID string `json:"-" gorm:"index"`
// Name of the Policy
Name string
// Description of the policy visible in the UI
Description string
// Enabled status of the policy
Enabled bool
// Rules of the policy
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`
}
// Copy returns a deep copy of the policy: each rule is duplicated through
// PolicyRule.Copy and the posture check ID list gets its own backing array.
// Both slices of the copy are non-nil even when the receiver's are nil.
func (p *Policy) Copy() *Policy {
	rules := make([]*PolicyRule, len(p.Rules))
	for idx := range p.Rules {
		rules[idx] = p.Rules[idx].Copy()
	}

	checks := make([]string, len(p.SourcePostureChecks))
	copy(checks, p.SourcePostureChecks)

	return &Policy{
		ID:                  p.ID,
		AccountID:           p.AccountID,
		Name:                p.Name,
		Description:         p.Description,
		Enabled:             p.Enabled,
		Rules:               rules,
		SourcePostureChecks: checks,
	}
}
// EventMeta builds the metadata attached to activity events for this policy.
// The map holds a single entry, "name", set to the policy name.
func (p *Policy) EventMeta() map[string]any {
	meta := make(map[string]any, 1)
	meta["name"] = p.Name
	return meta
}
// UpgradeAndFix migrates rules written by older versions to the latest shape.
// It mutates the rules in place.
func (p *Policy) UpgradeAndFix() {
	for idx := range p.Rules {
		rule := p.Rules[idx]

		// Migration from v0.20.3 to v0.20.4: a rule without a protocol applies
		// to all protocols, and an all-protocol rule is always bidirectional.
		if rule.Protocol == "" {
			rule.Protocol = PolicyRuleProtocolALL
		}
		if rule.Protocol == PolicyRuleProtocolALL {
			rule.Bidirectional = true
		}
	}
}
// ruleGroups collects the group IDs referenced by the policy's rules. For each
// rule, in order, its sources are listed first and its destinations second.
// Duplicates are kept, and the result is never nil.
func (p *Policy) ruleGroups() []string {
	ids := []string{}
	for idx := range p.Rules {
		rule := p.Rules[idx]
		ids = append(append(ids, rule.Sources...), rule.Destinations...)
	}
	return ids
}
// FirewallRule is a rule of the firewall, generated from a PolicyRule for one
// peer, one direction and at most one port.
type FirewallRule struct {
// PeerIP is the IP of the peer the rule refers to; the connection resources
// generator sets it to "0.0.0.0" when the peer list covers all other peers
// of the All group
PeerIP string
// Direction of the traffic (firewallRuleDirectionIN or firewallRuleDirectionOUT)
Direction int
// Action of the traffic
Action string
// Protocol of the traffic
Protocol string
// Port of the traffic; left empty for rules generated from a policy rule
// that lists no ports
Port string
}
// getPeerConnectionResources computes what the peer identified by peerID may
// connect to.
//
// It walks every enabled rule of every enabled policy of the account and
// returns the peers and firewall rules that apply to the given peer. Posture
// checks of a policy are applied to the source groups only. Peers that are
// not present in validatedPeersMap are ignored.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
	addResources, collect := a.connResourcesGenerator(ctx)

	for _, pol := range a.Policies {
		if !pol.Enabled {
			continue
		}

		for _, r := range pol.Rules {
			if !r.Enabled {
				continue
			}

			srcPeers, inSources := a.getAllPeersFromGroups(ctx, r.Sources, peerID, pol.SourcePostureChecks, validatedPeersMap)
			dstPeers, inDestinations := a.getAllPeersFromGroups(ctx, r.Destinations, peerID, nil, validatedPeersMap)

			// A bidirectional rule first contributes the reverse direction for
			// each side the peer belongs to.
			if r.Bidirectional && inSources {
				addResources(r, dstPeers, firewallRuleDirectionIN)
			}
			if r.Bidirectional && inDestinations {
				addResources(r, srcPeers, firewallRuleDirectionOUT)
			}

			// Forward direction: sources go out to destinations, and
			// destinations accept traffic coming in from sources.
			if inSources {
				addResources(r, dstPeers, firewallRuleDirectionOUT)
			}
			if inDestinations {
				addResources(r, srcPeers, firewallRuleDirectionIN)
			}
		}
	}

	return collect()
}
// connResourcesGenerator returns a generator function and an accumulator
// function that share state.
//
// Each generator call adds the given peers, and the firewall rules derived
// from the given policy rule and direction, to the shared result. Peers are
// deduplicated by ID. Firewall rules are deduplicated by a key made of rule
// ID, peer IP, direction, protocol, action and port list, so calling the
// generator repeatedly produces no duplicates. The accumulator returns
// everything collected so far.
//
// If the account's All group cannot be loaded, the error is logged and an
// empty group is used in its place.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
	var (
		seenRules = make(map[string]struct{})
		seenPeers = make(map[string]struct{})
		fwRules   = make([]*FirewallRule, 0)
		connPeers = make([]*nbpeer.Peer, 0)
	)

	allGroup, err := a.GetGroupAll()
	if err != nil {
		log.WithContext(ctx).Errorf("failed to get group all: %v", err)
		allGroup = &nbgroup.Group{}
	}

	generate := func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
		// The list covers "everyone" when it holds one peer fewer than the
		// All group has.
		coversAll := len(groupPeers) == len(allGroup.Peers)-1

		for _, p := range groupPeers {
			if p == nil {
				continue
			}

			if _, seen := seenPeers[p.ID]; !seen {
				seenPeers[p.ID] = struct{}{}
				connPeers = append(connPeers, p)
			}

			base := FirewallRule{
				PeerIP:    p.IP.String(),
				Direction: direction,
				Action:    string(rule.Action),
				Protocol:  string(rule.Protocol),
			}
			if coversAll {
				base.PeerIP = "0.0.0.0"
			}

			key := rule.ID + base.PeerIP + strconv.Itoa(direction) +
				base.Protocol + base.Action + strings.Join(rule.Ports, ",")
			if _, seen := seenRules[key]; seen {
				continue
			}
			seenRules[key] = struct{}{}

			if len(rule.Ports) == 0 {
				fwRules = append(fwRules, &base)
				continue
			}

			// One firewall rule per port, each a copy of the base rule.
			for _, port := range rule.Ports {
				withPort := base
				withPort.Port = port
				fwRules = append(fwRules, &withPort)
			}
		}
	}

	collect := func() ([]*nbpeer.Peer, []*FirewallRule) {
		return connPeers, fwRules
	}

	return generate, collect
}
// GetPolicy loads the policy identified by policyID from the store.
//
// The requesting user must belong to accountID and must not be a regular
// user; otherwise the corresponding status error is returned.
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return nil, err
	}

	switch {
	case user.AccountID != accountID:
		return nil, status.NewUserNotPartOfAccountError()
	case user.IsRegularUser():
		return nil, status.NewAdminPermissionError()
	}

	return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
}
// SavePolicy creates the policy when policy.ID is empty and updates it
// otherwise.
//
// The requesting user must belong to accountID and must not be a regular
// user. Inside a single store transaction the policy is validated, its impact
// on peers is determined, the network serial is incremented and the policy is
// persisted. After a successful transaction an activity event is stored and,
// if the change affects peers, the account peers are updated.
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
	unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
	defer unlock()

	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return nil, err
	}

	switch {
	case user.AccountID != accountID:
		return nil, status.NewUserNotPartOfAccountError()
	case user.IsRegularUser():
		return nil, status.NewAdminPermissionError()
	}

	// Decide create vs. update before validation assigns an ID to new policies.
	isUpdate := policy.ID != ""
	action := activity.PolicyAdded
	if isUpdate {
		action = activity.PolicyUpdated
	}

	var peersAffected bool
	err = am.Store.ExecuteInTransaction(ctx, func(tx Store) error {
		if txErr := validatePolicy(ctx, tx, accountID, policy); txErr != nil {
			return txErr
		}

		affected, txErr := arePolicyChangesAffectPeers(ctx, tx, accountID, policy, isUpdate)
		if txErr != nil {
			return txErr
		}
		peersAffected = affected

		if txErr = tx.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); txErr != nil {
			return txErr
		}

		if isUpdate {
			return tx.SavePolicy(ctx, LockingStrengthUpdate, policy)
		}
		return tx.CreatePolicy(ctx, LockingStrengthUpdate, policy)
	})
	if err != nil {
		return nil, err
	}

	am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())

	if peersAffected {
		am.updateAccountPeers(ctx, accountID)
	}

	return policy, nil
}
// DeletePolicy removes the policy identified by policyID from the store.
//
// The requesting user must belong to accountID and must not be a regular
// user. Inside a single store transaction the policy is loaded, its impact on
// peers is determined, the network serial is incremented and the policy is
// deleted. After a successful transaction an activity event is stored and, if
// the removal affects peers, the account peers are updated.
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
	unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
	defer unlock()

	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return err
	}

	switch {
	case user.AccountID != accountID:
		return status.NewUserNotPartOfAccountError()
	case user.IsRegularUser():
		return status.NewAdminPermissionError()
	}

	var (
		policy        *Policy
		peersAffected bool
	)
	err = am.Store.ExecuteInTransaction(ctx, func(tx Store) error {
		found, txErr := tx.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
		if txErr != nil {
			return txErr
		}
		policy = found

		affected, txErr := arePolicyChangesAffectPeers(ctx, tx, accountID, policy, false)
		if txErr != nil {
			return txErr
		}
		peersAffected = affected

		if txErr = tx.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); txErr != nil {
			return txErr
		}

		return tx.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
	})
	if err != nil {
		return err
	}

	am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())

	if peersAffected {
		am.updateAccountPeers(ctx, accountID)
	}

	return nil
}
// ListPolicies returns all policies of the account from the store.
//
// The requesting user must belong to accountID and must not be a regular
// user; otherwise the corresponding status error is returned.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return nil, err
	}

	switch {
	case user.AccountID != accountID:
		return nil, status.NewUserNotPartOfAccountError()
	case user.IsRegularUser():
		return nil, status.NewAdminPermissionError()
	}

	return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
// arePolicyChangesAffectPeers reports whether creating, updating or deleting
// the given policy can affect any peer.
//
// For an update (isUpdate is true) the stored version of the policy is loaded
// first. If both the stored and the new version are disabled, nothing changes
// for peers and false is returned. Otherwise the groups referenced by the
// stored version are checked for peers before the groups of the new version.
// When isUpdate is false only the groups of the given policy are checked.
//
// All lookups are scoped to the accountID parameter. policy.AccountID is
// deliberately not used: it is supplied by the caller and validatePolicy only
// fills it in for newly created policies.
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
	if isUpdate {
		existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
		if err != nil {
			return false, err
		}

		if !policy.Enabled && !existingPolicy.Enabled {
			return false, nil
		}

		hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, existingPolicy.ruleGroups())
		if err != nil {
			return false, err
		}
		if hasPeers {
			return true, nil
		}
	}

	return anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
}
// validatePolicy validates the policy and normalizes it in place.
//
// A policy with an ID must already exist in the account; a policy without one
// is assigned a fresh ID and the given accountID. Every rule is replaced by a
// copy in which a missing rule ID is generated (and the policy ID attached)
// and source/destination group IDs that do not exist in the account are
// dropped. Posture check IDs that do not exist in the account are dropped as
// well; a nil SourcePostureChecks is left nil.
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
	if policy.ID == "" {
		policy.ID = xid.New().String()
		policy.AccountID = accountID
	} else if _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID); err != nil {
		return err
	}

	knownGroups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
	if err != nil {
		return err
	}

	knownChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
	if err != nil {
		return err
	}

	for idx := range policy.Rules {
		sanitized := policy.Rules[idx].Copy()
		if sanitized.ID == "" {
			sanitized.ID = xid.New().String()
			sanitized.PolicyID = policy.ID
		}

		sanitized.Sources = getValidGroupIDs(knownGroups, sanitized.Sources)
		sanitized.Destinations = getValidGroupIDs(knownGroups, sanitized.Destinations)
		policy.Rules[idx] = sanitized
	}

	if policy.SourcePostureChecks != nil {
		policy.SourcePostureChecks = getValidPostureCheckIDs(knownChecks, policy.SourcePostureChecks)
	}

	return nil
}
// getAllPeersFromGroups gathers the peers of the given groups as seen from the
// peer identified by peerID.
//
// It returns the peers of those groups that pass the given posture checks and
// are present in validatedPeersMap, plus a boolean telling whether peerID
// itself is such a peer. The peer identified by peerID is never part of the
// returned list. Unknown groups and unknown peers are skipped. A peer that is
// a member of several of the groups appears once per membership.
//
// Important: posture checks apply to source group peers only. For destination
// group peers, call this method with an empty sourcePostureChecksIDs.
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
	var selfFound bool
	collected := make([]*nbpeer.Peer, 0, len(groups))

	for _, groupID := range groups {
		grp, found := a.Groups[groupID]
		if !found {
			continue
		}

		for _, memberID := range grp.Peers {
			// A missing map entry yields nil, so one nil test covers both
			// "unknown peer" and "stored as nil".
			member := a.Peers[memberID]
			if member == nil {
				continue
			}

			// The peer must satisfy the posture checks of the policy ...
			if !a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, member.ID) {
				continue
			}

			// ... and be part of the validated peer set.
			if _, validated := validatedPeersMap[member.ID]; !validated {
				continue
			}

			if member.ID == peerID {
				selfFound = true
				continue
			}

			collected = append(collected, member)
		}
	}

	return collected, selfFound
}
// validatePostureChecksOnPeer reports whether the peer identified by peerID
// passes every check of every posture check set listed in
// sourcePostureChecksID.
//
// It returns false when the peer is unknown to the account (missing from
// a.Peers or stored as nil) or when any single check fails. IDs that do not
// resolve to a posture check set of the account are skipped, so an empty or
// entirely unknown ID list yields true for a known peer.
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
	peer, ok := a.Peers[peerID]
	// Reject a missing entry as well as an entry that is present but nil: the
	// checks below dereference the peer, so a nil value must not get through.
	if !ok || peer == nil {
		return false
	}

	for _, postureChecksID := range sourcePostureChecksID {
		postureChecks := a.getPostureChecks(postureChecksID)
		if postureChecks == nil {
			continue
		}

		for _, check := range postureChecks.GetChecks() {
			isValid, err := check.Check(ctx, *peer)
			if err != nil {
				// A check error is only logged; the verdict comes from isValid.
				log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
			}
			if !isValid {
				return false
			}
		}
	}

	return true
}
// getPostureChecks looks up the account's posture checks by ID. It returns the
// first entry of a.PostureChecks whose ID equals postureChecksID, or nil when
// there is none.
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
	for _, candidate := range a.PostureChecks {
		if candidate.ID != postureChecksID {
			continue
		}
		return candidate
	}
	return nil
}
// getValidPostureCheckIDs returns the IDs from postureChecksIds that are keys
// of postureChecks. Input order and duplicates are preserved; the result is
// never nil.
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
	kept := make([]string, 0, len(postureChecksIds))
	for idx := range postureChecksIds {
		id := postureChecksIds[idx]
		if _, known := postureChecks[id]; !known {
			continue
		}
		kept = append(kept, id)
	}
	return kept
}
// getValidGroupIDs returns the IDs from groupIDs that are keys of groups.
// Input order and duplicates are preserved; the result is never nil.
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
	kept := make([]string, 0, len(groupIDs))
	for idx := range groupIDs {
		id := groupIDs[idx]
		if _, known := groups[id]; !known {
			continue
		}
		kept = append(kept, id)
	}
	return kept
}
// toProtocolFirewallRules converts firewall rules into their protocol (proto)
// representation. The result has one entry per input rule, in the same order.
// Direction, action and protocol are translated through the getProto*
// helpers; peer IP and port are carried over unchanged.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
	converted := make([]*proto.FirewallRule, 0, len(rules))
	for _, fr := range rules {
		converted = append(converted, &proto.FirewallRule{
			PeerIP:    fr.PeerIP,
			Direction: getProtoDirection(fr.Direction),
			Action:    getProtoAction(fr.Action),
			Protocol:  getProtoProtocol(fr.Protocol),
			Port:      fr.Port,
		})
	}
	return converted
}