netbird/client/internal/acl/manager.go
Zoltan Papp 006ba32086
Fix/acl for forward (#1305)
Fix ACL on routed traffic and code refactor
2023-12-08 10:48:21 +01:00

455 lines
13 KiB
Go

package acl
import (
"crypto/md5"
"encoding/hex"
"fmt"
"net"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
}
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock()
defer d.mutex.Unlock()
start := time.Now()
defer func() {
total := 0
for _, pairs := range d.rulesPairs {
total += len(pairs)
}
log.Infof(
"ACL rules processed in: %v, total rules count: %d",
time.Since(start), total)
}()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
defer func() {
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := (networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled)
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
enableSSH = enableSSH && !ok
}
// if TCP protocol rules not squashed and SSH enabled
// we add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP 22 port).
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort),
})
}
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
// we have old version of management without rules handling, we should allow all traffic
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
log.Warn("this peer is connected to a NetBird Management service with an older version. Allowing all traffic from connected peers")
rules = append(rules,
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
},
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.FirewallRule_ALL,
},
)
}
newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
d.rollBack(newRulePairs)
break
}
if len(rules) > 0 {
d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
}
for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
}
delete(d.rulesPairs, pairID)
}
}
d.rulesPairs = newRulePairs
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
if r.Port != "" {
value, err := strconv.Atoi(r.Port)
if err != nil {
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
}
port = &firewall.Port{
Values: []int{value},
}
}
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, err
}
return ruleID, rules, nil
}
func (d *DefaultManager) addInRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
}
func (d *DefaultManager) addOutRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
}
// getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getRuleID(
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
) string {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic.
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
totalIPs++
}
}
type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
// 1. Any of the rule has DROP action type.
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
if drop {
protocols[r.Protocol] = map[string]int{}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = map[string]int{}
}
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
}
ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.FirewallRule_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
}
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.FirewallRuleProtocol{
mgmProto.FirewallRule_ALL,
mgmProto.FirewallRule_ICMP,
mgmProto.FirewallRule_TCP,
mgmProto.FirewallRule_UDP,
}
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
continue
}
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.FirewallRule_ACCEPT,
Protocol: protocol,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.FirewallRule_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
}
}
}
squash(in, mgmProto.FirewallRule_IN)
squash(out, mgmProto.FirewallRule_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
return squashedRules, squashedProtocols
}
if len(squashedRules) == 0 {
return networkMap.FirewallRules, squashedProtocols
}
var rules []*mgmProto.FirewallRule
// filter out rules which was squashed from final list
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
continue
}
}
rules = append(rules, r)
}
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP, nil
case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP, nil
case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP, nil
case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL, nil
default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) bool {
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
switch action {
case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept, nil
case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop, nil
default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}