Optimize ACL performance (#994)

* Optimize rules with All groups

* Use IP sets in ACLs (nftables implementation)

* Fix squash rule when we receive optimized rules list from management
This commit is contained in:
Givi Khojanashvili 2023-07-18 13:12:50 +04:00 committed by GitHub
parent 7ebe58f20a
commit e69ec6ab6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 727 additions and 114 deletions

View File

@ -51,6 +51,7 @@ type Manager interface {
dPort *Port, dPort *Port,
direction RuleDirection, direction RuleDirection,
action Action, action Action,
ipsetName string,
comment string, comment string,
) (Rule, error) ) (Rule, error)
@ -60,5 +61,8 @@ type Manager interface {
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset() error
// Flush the changes to firewall controller
Flush() error
// TODO: migrate routemanager firewal actions to this interface // TODO: migrate routemanager firewal actions to this interface
} }

View File

@ -92,6 +92,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
@ -202,6 +203,9 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it // reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error { func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName) ok, err := client.ChainExists(table, ChainInputFilterName)

View File

@ -68,7 +68,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
@ -81,7 +81,7 @@ func TestIptablesManager(t *testing.T) {
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
@ -107,7 +107,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic") _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset()
@ -167,9 +167,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@ -6,12 +6,14 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/google/uuid" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall"
@ -29,11 +31,14 @@ const (
FilterOutputChainName = "netbird-acl-output-filter" FilterOutputChainName = "netbird-acl-output-filter"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
conn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn
tableIPv4 *nftables.Table tableIPv4 *nftables.Table
tableIPv6 *nftables.Table tableIPv6 *nftables.Table
@ -43,6 +48,10 @@ type Manager struct {
filterInputChainIPv6 *nftables.Chain filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
wgIface iFaceMapper wgIface iFaceMapper
} }
@ -54,8 +63,23 @@ type iFaceMapper interface {
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
m := &Manager{ m := &Manager{
conn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn,
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
wgIface: wgIface, wgIface: wgIface,
} }
@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
var ( var (
err error err error
ipset *nftables.Set
table *nftables.Table table *nftables.Table
chain *nftables.Chain chain *nftables.Chain
) )
@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
return nil, err return nil, err
} }
rawIP := ip.To4()
if rawIP == nil {
rawIP = ip.To16()
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err := m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME ifaceKey = expr.MetaKeyOIFNAME
@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
}) })
} }
// don't use IP matching if IP is ip 0.0.0.0 // check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
if s := ip.String(); s != "0.0.0.0" && s != "::" { // in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
var adrLen, adrOffset uint32 addrLen := uint32(len(rawIP))
if ip.To4() == nil { addrOffset := uint32(12)
adrLen = 16 if addrLen == 16 {
adrOffset = 8 addrOffset = 8
} else {
adrLen = 4
adrOffset = 12
} }
// change to destination address position if need // change to destination address position if need
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
adrOffset += adrLen addrOffset += addrLen
} }
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: adrOffset, Offset: addrOffset,
Len: adrLen, Len: addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
}, },
) )
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
} }
if sPort != nil && len(sPort.Values) != 0 { if sPort != nil && len(sPort.Values) != 0 {
@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
id := uuid.New().String() userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
userData := []byte(strings.Join([]string{id, comment}, " "))
_ = m.conn.InsertRule(&nftables.Rule{ rule := m.rConn.InsertRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Position: 0, Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
if err := m.rConn.Flush(); err != nil {
if err := m.conn.Flush(); err != nil { return nil, fmt.Errorf("flush insert rule: %v", err)
return nil, err
} }
list, err := m.conn.GetRules(table, chain) ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
if err != nil { return m.rulesetManager.addRule(ruleset, rawIP)
return nil, err }
// getRulesetID returns ruleset ID based on given parameters
func (m *Manager) getRulesetID(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
}
// createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
} }
// Add the rule to the chain if err := m.rConn.AddSet(ipset, nil); err != nil {
rule := &Rule{id: id} return nil, fmt.Errorf("create set: %v", err)
for _, r := range list {
if bytes.Equal(r.UserData, userData) {
rule.Rule = r
break
}
}
if rule.Rule == nil {
return nil, fmt.Errorf("rule not found")
} }
return rule, nil if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
} }
// chain returns the chain for the given IP address with specific settings // chain returns the chain for the given IP address with specific settings
@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
} }
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.conn.ListTablesOfFamily(family) tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
} }
} }
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
} }
func (m *Manager) createChainIfNotExists( func (m *Manager) createChainIfNotExists(
@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
return nil, err return nil, err
} }
chains, err := m.conn.ListChainsOfTableFamily(family) chains, err := m.rConn.ListChainsOfTableFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of chains: %w", err) return nil, fmt.Errorf("list of chains: %w", err)
} }
@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
Policy: &polAccept, Policy: &polAccept,
} }
chain = m.conn.AddChain(chain) chain = m.rConn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0 shiftDSTAddr := 0
@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
) )
} }
_ = m.conn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: expressions,
@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
_ = m.conn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
Exprs: expressions, Exprs: expressions,
}) })
if err := m.conn.Flush(); err != nil {
if err := m.rConn.Flush(); err != nil {
return nil, err return nil, err
} }
@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error { func (m *Manager) DeleteRule(rule fw.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule) nativeRule, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if err := m.conn.DelRule(nativeRule.Rule); err != nil { if nativeRule.nftRule == nil {
return err return nil
} }
return m.conn.Flush() if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return err
}
nativeRule.nftRule = nil
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state
@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
chains, err := m.conn.ListChains() chains, err := m.rConn.ListChains()
if err != nil { if err != nil {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
for _, c := range chains { for _, c := range chains {
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.conn.DelChain(c) m.rConn.DelChain(c)
} }
} }
tables, err := m.conn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == FilterTableName { if t.Name == FilterTableName {
m.conn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return m.conn.Flush() return m.rConn.Flush()
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.flushWithBackoff(); err != nil {
return err
}
// set must be removed after flush rule changes
// otherwise we will get error
for _, s := range m.setRemoved {
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
if len(m.setRemoved) > 0 {
if err := m.flushWithBackoff(); err != nil {
return err
}
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime = backoffTime * 2
continue
}
break
}
return
}
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
if table == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(table, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) != 0 {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
log.Errorf("failed to set rule handle: %v", err)
}
}
}
return nil
} }
func encodePort(port fw.Port) []byte { func encodePort(port fw.Port) []byte {

View File

@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset() err = manager.Reset()
@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
fw.RuleDirectionIN, fw.RuleDirectionIN,
fw.ActionDrop, fw.ActionDrop,
"", "",
"",
) )
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: // test expectations:
// 1) regular rule // 1) regular rule
// 2) "accept extra routed traffic rule" for the interface // 2) "accept extra routed traffic rule" for the interface
@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
err = manager.DeleteRule(rule) err = manager.DeleteRule(rule)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: // test expectations:
@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(); err != nil {
@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
err = manager.Flush()
require.NoError(t, err, "failed to flush")
}
} }
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax)) t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
}) })
} }

View File

@ -6,11 +6,14 @@ import (
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
*nftables.Rule nftRule *nftables.Rule
id string nftSet *nftables.Set
ruleID string
ip []byte
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) GetRuleID() string {
return r.id return r.ruleID
} }

View File

@ -0,0 +1,115 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@ -0,0 +1,122 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@ -84,6 +84,7 @@ func (m *Manager) AddFiltering(
dPort *fw.Port, dPort *fw.Port,
direction fw.RuleDirection, direction fw.RuleDirection,
action fw.Action, action fw.Action,
ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) (fw.Rule, error) {
r := Rule{ r := Rule{
@ -181,6 +182,9 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.dropFilter(packetData, m.outgoingRules, false) return m.dropFilter(packetData, m.outgoingRules, false)

View File

@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -236,7 +236,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment) _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -274,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment) _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@ -390,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@ -33,9 +33,22 @@ type Manager interface {
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
manager firewall.Manager manager firewall.Manager
rulesPairs map[string][]firewall.Rule ipsetCounter int
mutex sync.Mutex rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
} }
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return return
} }
defer func() {
if err := d.manager.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := (networkMap.PeerConfig != nil && enableSSH := (networkMap.PeerConfig != nil &&
@ -108,8 +127,32 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
applyFailed := false applyFailed := false
newRulePairs := make(map[string][]firewall.Rule) newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules { for _, r := range rules {
pairID, rulePair, err := d.protoRuleToFirewallRule(r) selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
ipsetName := ""
if ipset.name == "" {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
}
ipsetName = ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true applyFailed = true
@ -154,7 +197,10 @@ func (d *DefaultManager) Stop() {
} }
} }
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) { func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP) ip := net.ParseIP(r.PeerIP)
if ip == nil { if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@ -190,9 +236,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
var err error var err error
switch r.Direction { switch r.Direction {
case mgmProto.FirewallRule_IN: case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.FirewallRule_OUT: case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
@ -205,9 +251,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
return ruleID, rules, nil return ruleID, rules, nil
} }
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { func (d *DefaultManager) addInRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment) rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@ -217,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment) rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@ -225,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
return append(rules, rule), nil return append(rules, rule), nil
} }
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { func (d *DefaultManager) addOutRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment) rule, err := d.manager.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@ -237,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment) rule, err = d.manager.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
@ -282,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
in := protoMatch{} in := protoMatch{}
out := protoMatch{} out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not. // this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list. // We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if: // But we zeroed the IP's for protocol if:
@ -298,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = map[string]int{}
} }
match := protocols[r.Protocol]
if _, ok := match[r.PeerIP]; ok { // special case, when we recieve this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return return
} }
match[r.PeerIP] = i
ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
} }
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
@ -324,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.FirewallRule_UDP, mgmProto.FirewallRule_UDP,
} }
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
@ -382,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
return append(rules, squashedRules...), squashedProtocols return append(rules, squashedRules...), squashedProtocols
} }
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol { func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
switch protocol { switch protocol {
case mgmProto.FirewallRule_TCP: case mgmProto.FirewallRule_TCP:

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
) )
@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &DefaultManager{ return newDefaultManager(fm), nil
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
} }
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
} }
} }
return &DefaultManager{ return newDefaultManager(fm), nil
manager: fm,
rulesPairs: make(map[string][]firewall.Rule),
}, nil
} }

View File

@ -2,9 +2,11 @@ package server
import ( import (
_ "embed" _ "embed"
"fmt" "strconv"
"strings" "strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -240,7 +242,15 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
peers := make([]*Peer, 0) peers := make([]*Peer, 0)
all, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
all = &Group{}
}
return func(rule *PolicyRule, groupPeers []*Peer, direction int) { return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers { for _, peer := range groupPeers {
if peer == nil { if peer == nil {
continue continue
@ -250,29 +260,33 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
peersExists[peer.ID] = struct{}{} peersExists[peer.ID] = struct{}{}
} }
fwRule := FirewallRule{ fr := FirewallRule{
PeerIP: peer.IP.String(), PeerIP: peer.IP.String(),
Direction: direction, Direction: direction,
Action: string(rule.Action), Action: string(rule.Action),
Protocol: string(rule.Protocol), Protocol: string(rule.Protocol),
} }
ruleID := fmt.Sprintf("%s%d", peer.ID+peer.IP.String(), direction) if isAll {
ruleID += string(rule.Protocol) + string(rule.Action) + strings.Join(rule.Ports, ",") fr.PeerIP = "0.0.0.0"
}
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
if _, ok := rulesExists[ruleID]; ok { if _, ok := rulesExists[ruleID]; ok {
continue continue
} }
rulesExists[ruleID] = struct{}{} rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 { if len(rule.Ports) == 0 {
rules = append(rules, &fwRule) rules = append(rules, &fr)
continue continue
} }
for _, port := range rule.Ports { for _, port := range rule.Ports {
addRule := fwRule pr := fr // clone rule and add set new port
addRule.Port = port pr.Port = port
rules = append(rules, &addRule) rules = append(rules, &pr)
} }
} }
}, func() ([]*Peer, []*FirewallRule) { }, func() ([]*Peer, []*FirewallRule) {

View File

@ -126,6 +126,20 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerF"]) assert.Contains(t, peers, account.Peers["peerF"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
{
PeerIP: "0.0.0.0",
Direction: firewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "0.0.0.0",
Direction: firewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{ {
PeerIP: "100.65.14.88", PeerIP: "100.65.14.88",
Direction: firewallRuleDirectionIN, Direction: firewallRuleDirectionIN,