mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-15 03:11:02 +01:00
269 lines
7.4 KiB
Go
269 lines
7.4 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
|
|
"github.com/rs/xid"
|
|
|
|
"github.com/netbirdio/netbird/management/proto"
|
|
"github.com/netbirdio/netbird/management/server/store"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
|
|
"github.com/netbirdio/netbird/management/server/activity"
|
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
"github.com/netbirdio/netbird/management/server/posture"
|
|
"github.com/netbirdio/netbird/management/server/status"
|
|
)
|
|
|
|
// GetPolicy from the store
|
|
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
|
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.AccountID != accountID {
|
|
return nil, status.NewUserNotPartOfAccountError()
|
|
}
|
|
|
|
if user.IsRegularUser() {
|
|
return nil, status.NewAdminPermissionError()
|
|
}
|
|
|
|
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
|
|
}
|
|
|
|
// SavePolicy in the store
|
|
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
|
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
defer unlock()
|
|
|
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.AccountID != accountID {
|
|
return nil, status.NewUserNotPartOfAccountError()
|
|
}
|
|
|
|
if user.IsRegularUser() {
|
|
return nil, status.NewAdminPermissionError()
|
|
}
|
|
|
|
var isUpdate = policy.ID != ""
|
|
var updateAccountPeers bool
|
|
var action = activity.PolicyAdded
|
|
|
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
|
return err
|
|
}
|
|
|
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
|
return err
|
|
}
|
|
|
|
saveFunc := transaction.CreatePolicy
|
|
if isUpdate {
|
|
action = activity.PolicyUpdated
|
|
saveFunc = transaction.SavePolicy
|
|
}
|
|
|
|
return saveFunc(ctx, store.LockingStrengthUpdate, policy)
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
|
|
|
if updateAccountPeers {
|
|
am.updateAccountPeers(ctx, accountID)
|
|
}
|
|
|
|
return policy, nil
|
|
}
|
|
|
|
// DeletePolicy from the store
|
|
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
defer unlock()
|
|
|
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if user.AccountID != accountID {
|
|
return status.NewUserNotPartOfAccountError()
|
|
}
|
|
|
|
if user.IsRegularUser() {
|
|
return status.NewAdminPermissionError()
|
|
}
|
|
|
|
var policy *types.Policy
|
|
var updateAccountPeers bool
|
|
|
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
|
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
|
|
|
if updateAccountPeers {
|
|
am.updateAccountPeers(ctx, accountID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListPolicies from the store.
|
|
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
|
|
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.AccountID != accountID {
|
|
return nil, status.NewUserNotPartOfAccountError()
|
|
}
|
|
|
|
if user.IsRegularUser() {
|
|
return nil, status.NewAdminPermissionError()
|
|
}
|
|
|
|
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
|
|
}
|
|
|
|
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
|
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
|
if isUpdate {
|
|
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if !policy.Enabled && !existingPolicy.Enabled {
|
|
return false, nil
|
|
}
|
|
|
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if hasPeers {
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
|
}
|
|
|
|
// validatePolicy validates the policy and its rules.
|
|
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
|
if policy.ID != "" {
|
|
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
policy.ID = xid.New().String()
|
|
policy.AccountID = accountID
|
|
}
|
|
|
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i, rule := range policy.Rules {
|
|
ruleCopy := rule.Copy()
|
|
if ruleCopy.ID == "" {
|
|
ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
|
|
ruleCopy.PolicyID = policy.ID
|
|
}
|
|
|
|
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
|
|
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
|
|
policy.Rules[i] = ruleCopy
|
|
}
|
|
|
|
if policy.SourcePostureChecks != nil {
|
|
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
|
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
|
validIDs := make([]string, 0, len(postureChecksIds))
|
|
for _, id := range postureChecksIds {
|
|
if _, exists := postureChecks[id]; exists {
|
|
validIDs = append(validIDs, id)
|
|
}
|
|
}
|
|
|
|
return validIDs
|
|
}
|
|
|
|
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
|
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
|
|
validIDs := make([]string, 0, len(groupIDs))
|
|
for _, id := range groupIDs {
|
|
if _, exists := groups[id]; exists {
|
|
validIDs = append(validIDs, id)
|
|
}
|
|
}
|
|
|
|
return validIDs
|
|
}
|
|
|
|
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
|
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
|
|
result := make([]*proto.FirewallRule, len(rules))
|
|
for i := range rules {
|
|
rule := rules[i]
|
|
|
|
result[i] = &proto.FirewallRule{
|
|
PeerIP: rule.PeerIP,
|
|
Direction: getProtoDirection(rule.Direction),
|
|
Action: getProtoAction(rule.Action),
|
|
Protocol: getProtoProtocol(rule.Protocol),
|
|
Port: rule.Port,
|
|
}
|
|
}
|
|
return result
|
|
}
|