mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
ACL firewall manager fix/improvement (#970)
* ACL firewall manager fix/improvement Fix issue with rule squashing, it contained issue when calculated total amount of IPs in the Peer map (doesn't included offline peers). That why squashing not worked. Also this commit changes the rules apply behaviour. Instead policy: 1. Apply all rules from network map 2. Remove all previous applied rules We do: 1. Apply only new rules 2. Remove outdated rules Why first variant was implemented: because when you have drop policy it is important in which order order you rules are and you need totally clean previous state to apply the new. But in the release we didn't include drop policy so we can do this improvement. * Print log message about processed ACL rules
This commit is contained in:
parent
20ae540fb1
commit
c20f98c8b6
@ -1,5 +1,9 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
|
||||
@ -28,3 +32,15 @@ type Port struct {
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
}
|
||||
|
||||
// String interface implementation
|
||||
func (p *Port) String() string {
|
||||
var ports string
|
||||
for _, port := range p.Values {
|
||||
if ports != "" {
|
||||
ports += ","
|
||||
}
|
||||
ports += strconv.Itoa(port)
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
@ -1,10 +1,13 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -42,6 +45,17 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
total := 0
|
||||
for _, pairs := range d.rulesPairs {
|
||||
total += len(pairs)
|
||||
}
|
||||
log.Infof(
|
||||
"ACL rules processed in: %v, total rules count: %d",
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
if d.manager == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
@ -95,13 +109,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
applyFailed := false
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
for _, r := range rules {
|
||||
rulePair, err := d.protoRuleToFirewallRule(r)
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
applyFailed = true
|
||||
break
|
||||
}
|
||||
newRulePairs[rulePair[0].GetRuleID()] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
if applyFailed {
|
||||
log.Error("failed to apply firewall rules, rollback ACL to previous state")
|
||||
@ -140,33 +154,38 @@ func (d *DefaultManager) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
|
||||
protocol := convertToFirewallProtocol(r.Protocol)
|
||||
if protocol == firewall.ProtocolUnknown {
|
||||
return nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
|
||||
}
|
||||
|
||||
action := convertFirewallAction(r.Action)
|
||||
if action == firewall.ActionUnknown {
|
||||
return nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
}
|
||||
port = &firewall.Port{
|
||||
Values: []int{value},
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
|
||||
if rulesPair, ok := d.rulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
var err error
|
||||
switch r.Direction {
|
||||
@ -175,15 +194,15 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]fi
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
d.rulesPairs[rules[0].GetRuleID()] = rules
|
||||
return rules, nil
|
||||
d.rulesPairs[ruleID] = rules
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
@ -226,6 +245,23 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return append(rules, rule), nil
|
||||
}
|
||||
|
||||
// getRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getRuleID(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
comment string,
|
||||
) string {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
||||
//
|
||||
@ -235,7 +271,7 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range networkMap.RemotePeers {
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
totalIPs++
|
||||
}
|
||||
|
@ -55,6 +55,11 @@ func TestDefaultManager(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
for id := range acl.rulesPairs {
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[1:]
|
||||
networkMap.FirewallRules = append(
|
||||
@ -67,11 +72,6 @@ func TestDefaultManager(t *testing.T) {
|
||||
},
|
||||
)
|
||||
|
||||
existedRulesID := map[string]struct{}{}
|
||||
for id := range acl.rulesPairs {
|
||||
existedRulesID[id] = struct{}{}
|
||||
}
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
@ -80,13 +80,16 @@ func TestDefaultManager(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// check that old rules was removed
|
||||
for id := range existedRulesID {
|
||||
if _, ok := acl.rulesPairs[id]; ok {
|
||||
t.Errorf("old rule was not removed")
|
||||
return
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.rulesPairs {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
if previousCount != 1 {
|
||||
t.Errorf("old rule was not removed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handle default rules", func(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user