mirror of
https://github.com/netbirdio/netbird.git
synced 2025-05-30 06:40:15 +02:00
Optimize ACL performance (#994)
* Optimize rules with All groups * Use IP sets in ACLs (nftables implementation) * Fix squash rule when we receive optimized rules list from management
This commit is contained in:
parent
7ebe58f20a
commit
e69ec6ab6a
@ -51,6 +51,7 @@ type Manager interface {
|
|||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
direction RuleDirection,
|
||||||
action Action,
|
action Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@ -60,5 +61,8 @@ type Manager interface {
|
|||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
Reset() error
|
Reset() error
|
||||||
|
|
||||||
|
// Flush the changes to firewall controller
|
||||||
|
Flush() error
|
||||||
|
|
||||||
// TODO: migrate routemanager firewal actions to this interface
|
// TODO: migrate routemanager firewal actions to this interface
|
||||||
}
|
}
|
||||||
|
@ -92,6 +92,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@ -202,6 +203,9 @@ func (m *Manager) Reset() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// reset firewall chain, clear it and drop it
|
// reset firewall chain, clear it and drop it
|
||||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||||
|
@ -68,7 +68,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("add first rule", func(t *testing.T) {
|
t.Run("add first rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.2")
|
ip := net.ParseIP("10.20.0.2")
|
||||||
port := &fw.Port{Values: []int{8080}}
|
port := &fw.Port{Values: []int{8080}}
|
||||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||||
@ -81,7 +81,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
Values: []int{8043: 8046},
|
Values: []int{8043: 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFiltering(
|
rule2, err = manager.AddFiltering(
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||||
@ -107,7 +107,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []int{5353}}
|
||||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@ -167,9 +167,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
@ -6,12 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/google/uuid"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall"
|
fw "github.com/netbirdio/netbird/client/firewall"
|
||||||
@ -29,11 +31,14 @@ const (
|
|||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
FilterOutputChainName = "netbird-acl-output-filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
|
||||||
// Manager of iptables firewall
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
|
sConn *nftables.Conn
|
||||||
tableIPv4 *nftables.Table
|
tableIPv4 *nftables.Table
|
||||||
tableIPv6 *nftables.Table
|
tableIPv6 *nftables.Table
|
||||||
|
|
||||||
@ -43,6 +48,10 @@ type Manager struct {
|
|||||||
filterInputChainIPv6 *nftables.Chain
|
filterInputChainIPv6 *nftables.Chain
|
||||||
filterOutputChainIPv6 *nftables.Chain
|
filterOutputChainIPv6 *nftables.Chain
|
||||||
|
|
||||||
|
rulesetManager *rulesetManager
|
||||||
|
setRemovedIPs map[string]struct{}
|
||||||
|
setRemoved map[string]*nftables.Set
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,8 +63,23 @@ type iFaceMapper interface {
|
|||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||||
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
|
// and is permanent. Using same connection for booth type of operations
|
||||||
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
conn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
|
sConn: sConn,
|
||||||
|
|
||||||
|
rulesetManager: newRuleManager(),
|
||||||
|
setRemovedIPs: map[string]struct{}{},
|
||||||
|
setRemoved: map[string]*nftables.Set{},
|
||||||
|
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
ipset *nftables.Set
|
||||||
table *nftables.Table
|
table *nftables.Table
|
||||||
chain *nftables.Chain
|
chain *nftables.Chain
|
||||||
)
|
)
|
||||||
@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rawIP := ip.To4()
|
||||||
|
if rawIP == nil {
|
||||||
|
rawIP = ip.To16()
|
||||||
|
}
|
||||||
|
|
||||||
|
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||||
|
|
||||||
|
if ipsetName != "" {
|
||||||
|
// if we already have set with given name, just add ip to the set
|
||||||
|
// and return rule with new ID in other case let's create rule
|
||||||
|
// with fresh created set and set element
|
||||||
|
|
||||||
|
var isSetNew bool
|
||||||
|
ipset, err := m.rConn.GetSetByName(table, ipsetName)
|
||||||
|
if err != nil {
|
||||||
|
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
|
}
|
||||||
|
isSetNew = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSetNew {
|
||||||
|
// if we already have nftables rules with set for given direction
|
||||||
|
// just add new rule to the ruleset and return new fw.Rule object
|
||||||
|
|
||||||
|
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||||
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
|
}
|
||||||
|
// if ipset exists but it is not linked to rule for given direction
|
||||||
|
// create new rule for direction and bind ipset to it later
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
ifaceKey = expr.MetaKeyOIFNAME
|
ifaceKey = expr.MetaKeyOIFNAME
|
||||||
@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
// in that case not add IP match expression into the rule definition
|
||||||
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
var adrLen, adrOffset uint32
|
addrLen := uint32(len(rawIP))
|
||||||
if ip.To4() == nil {
|
addrOffset := uint32(12)
|
||||||
adrLen = 16
|
if addrLen == 16 {
|
||||||
adrOffset = 8
|
addrOffset = 8
|
||||||
} else {
|
|
||||||
adrLen = 4
|
|
||||||
adrOffset = 12
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// change to destination address position if need
|
// change to destination address position if need
|
||||||
if direction == fw.RuleDirectionOUT {
|
if direction == fw.RuleDirectionOUT {
|
||||||
adrOffset += adrLen
|
addrOffset += addrLen
|
||||||
}
|
}
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
Offset: adrOffset,
|
Offset: addrOffset,
|
||||||
Len: adrLen,
|
Len: addrLen,
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: add.AsSlice(),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
// add individual IP for match if no ipset defined
|
||||||
|
if ipset == nil {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rawIP,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ipsetName,
|
||||||
|
SetID: ipset.ID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
if sPort != nil && len(sPort.Values) != 0 {
|
||||||
@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
|
|||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
id := uuid.New().String()
|
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
|
||||||
|
|
||||||
_ = m.conn.InsertRule(&nftables.Rule{
|
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Position: 0,
|
Position: 0,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
if err := m.conn.Flush(); err != nil {
|
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
list, err := m.conn.GetRules(table, chain)
|
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||||
if err != nil {
|
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||||
return nil, err
|
}
|
||||||
|
|
||||||
|
// getRulesetID returns ruleset ID based on given parameters
|
||||||
|
func (m *Manager) getRulesetID(
|
||||||
|
ip net.IP,
|
||||||
|
proto fw.Protocol,
|
||||||
|
sPort *fw.Port,
|
||||||
|
dPort *fw.Port,
|
||||||
|
direction fw.RuleDirection,
|
||||||
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) string {
|
||||||
|
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||||
|
if sPort != nil {
|
||||||
|
rulesetID += sPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
if dPort != nil {
|
||||||
|
rulesetID += dPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
rulesetID += strconv.Itoa(int(action))
|
||||||
|
if ipsetName == "" {
|
||||||
|
return "ip:" + ip.String() + rulesetID
|
||||||
|
}
|
||||||
|
return "set:" + ipsetName + rulesetID
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSet in given table by name
|
||||||
|
func (m *Manager) createSet(
|
||||||
|
table *nftables.Table,
|
||||||
|
rawIP []byte,
|
||||||
|
name string,
|
||||||
|
) (*nftables.Set, error) {
|
||||||
|
keyType := nftables.TypeIPAddr
|
||||||
|
if len(rawIP) == 16 {
|
||||||
|
keyType = nftables.TypeIP6Addr
|
||||||
|
}
|
||||||
|
// else we create new ipset and continue creating rule
|
||||||
|
ipset := &nftables.Set{
|
||||||
|
Name: name,
|
||||||
|
Table: table,
|
||||||
|
Dynamic: true,
|
||||||
|
KeyType: keyType,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the rule to the chain
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
rule := &Rule{id: id}
|
return nil, fmt.Errorf("create set: %v", err)
|
||||||
for _, r := range list {
|
|
||||||
if bytes.Equal(r.UserData, userData) {
|
|
||||||
rule.Rule = r
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if rule.Rule == nil {
|
|
||||||
return nil, fmt.Errorf("rule not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush created set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain returns the chain for the given IP address with specific settings
|
// chain returns the chain for the given IP address with specific settings
|
||||||
@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||||
tables, err := m.conn.ListTablesOfFamily(family)
|
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return table, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
func (m *Manager) createChainIfNotExists(
|
||||||
@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of chains: %w", err)
|
return nil, fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
Policy: &polAccept,
|
Policy: &polAccept,
|
||||||
}
|
}
|
||||||
|
|
||||||
chain = m.conn.AddChain(chain)
|
chain = m.rConn.AddChain(chain)
|
||||||
|
|
||||||
ifaceKey := expr.MetaKeyIIFNAME
|
ifaceKey := expr.MetaKeyIIFNAME
|
||||||
shiftDSTAddr := 0
|
shiftDSTAddr := 0
|
||||||
@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
},
|
},
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
}
|
}
|
||||||
_ = m.conn.AddRule(&nftables.Rule{
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: table,
|
Table: table,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: expressions,
|
||||||
})
|
})
|
||||||
if err := m.conn.Flush(); err != nil {
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
|
|||||||
|
|
||||||
// DeleteRule from the firewall by rule definition
|
// DeleteRule from the firewall by rule definition
|
||||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
nativeRule, ok := rule.(*Rule)
|
nativeRule, ok := rule.(*Rule)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid rule type")
|
return fmt.Errorf("invalid rule type")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
if nativeRule.nftRule == nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
if nativeRule.nftSet != nil {
|
||||||
|
// call twice of delete set element raises error
|
||||||
|
// so we need to check if element is already removed
|
||||||
|
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||||
|
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||||
|
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.setRemovedIPs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.rulesetManager.deleteRule(nativeRule) {
|
||||||
|
// deleteRule indicates that we still have IP in the ruleset
|
||||||
|
// it means we should not remove the nftables rule but need to update set
|
||||||
|
// so we prepare IP to be removed from set on the next flush call
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||||
|
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nativeRule.nftRule = nil
|
||||||
|
|
||||||
|
if nativeRule.nftSet != nil {
|
||||||
|
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||||
|
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||||
|
}
|
||||||
|
nativeRule.nftSet = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
chains, err := m.conn.ListChains()
|
chains, err := m.rConn.ListChains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||||
m.conn.DelChain(c)
|
m.rConn.DelChain(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tables, err := m.conn.ListTables()
|
tables, err := m.rConn.ListTables()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list of tables: %w", err)
|
return fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == FilterTableName {
|
if t.Name == FilterTableName {
|
||||||
m.conn.DelTable(t)
|
m.rConn.DelTable(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.conn.Flush()
|
return m.rConn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
//
|
||||||
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
func (m *Manager) Flush() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set must be removed after flush rule changes
|
||||||
|
// otherwise we will get error
|
||||||
|
for _, s := range m.setRemoved {
|
||||||
|
m.rConn.FlushSet(s)
|
||||||
|
m.rConn.DelSet(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.setRemoved) > 0 {
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.setRemovedIPs = map[string]struct{}{}
|
||||||
|
m.setRemoved = map[string]*nftables.Set{}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) flushWithBackoff() (err error) {
|
||||||
|
backoff := 4
|
||||||
|
backoffTime := 1000 * time.Millisecond
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("failed to flush nftables, retrying...")
|
||||||
|
if i == backoff-1 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(backoffTime)
|
||||||
|
backoffTime = backoffTime * 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||||
|
if table == nil || chain == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := m.rConn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range list {
|
||||||
|
if len(rule.UserData) != 0 {
|
||||||
|
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||||
|
log.Errorf("failed to set rule handle: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
func encodePort(port fw.Port) []byte {
|
||||||
|
@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset()
|
err = manager.Reset()
|
||||||
@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
fw.RuleDirectionIN,
|
fw.RuleDirectionIN,
|
||||||
fw.ActionDrop,
|
fw.ActionDrop,
|
||||||
"",
|
"",
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
// test expectations:
|
// test expectations:
|
||||||
// 1) regular rule
|
// 1) regular rule
|
||||||
// 2) "accept extra routed traffic rule" for the interface
|
// 2) "accept extra routed traffic rule" for the interface
|
||||||
@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
err = manager.DeleteRule(rule)
|
err = manager.DeleteRule(rule)
|
||||||
require.NoError(t, err, "failed to delete rule")
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// test expectations:
|
// test expectations:
|
||||||
@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(); err != nil {
|
if err := manager.Reset(); err != nil {
|
||||||
@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
|
if i%100 == 0 {
|
||||||
|
err = manager.Flush()
|
||||||
|
require.NoError(t, err, "failed to flush")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,14 @@ import (
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
*nftables.Rule
|
nftRule *nftables.Rule
|
||||||
id string
|
nftSet *nftables.Set
|
||||||
|
|
||||||
|
ruleID string
|
||||||
|
ip []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) GetRuleID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||||
|
type nftRuleset struct {
|
||||||
|
nftRule *nftables.Rule
|
||||||
|
nftSet *nftables.Set
|
||||||
|
issuedRules map[string]*Rule
|
||||||
|
rulesetID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type rulesetManager struct {
|
||||||
|
rulesets map[string]*nftRuleset
|
||||||
|
|
||||||
|
nftSetName2rulesetID map[string]string
|
||||||
|
issuedRuleID2rulesetID map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRuleManager() *rulesetManager {
|
||||||
|
return &rulesetManager{
|
||||||
|
rulesets: map[string]*nftRuleset{},
|
||||||
|
|
||||||
|
nftSetName2rulesetID: map[string]string{},
|
||||||
|
issuedRuleID2rulesetID: map[string]string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||||
|
ruleset, ok := r.rulesets[rulesetID]
|
||||||
|
return ruleset, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) createRuleset(
|
||||||
|
rulesetID string,
|
||||||
|
nftRule *nftables.Rule,
|
||||||
|
nftSet *nftables.Set,
|
||||||
|
) *nftRuleset {
|
||||||
|
ruleset := nftRuleset{
|
||||||
|
rulesetID: rulesetID,
|
||||||
|
nftRule: nftRule,
|
||||||
|
nftSet: nftSet,
|
||||||
|
issuedRules: map[string]*Rule{},
|
||||||
|
}
|
||||||
|
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||||
|
if nftSet != nil {
|
||||||
|
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||||
|
}
|
||||||
|
return &ruleset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rulesetManager) addRule(
|
||||||
|
ruleset *nftRuleset,
|
||||||
|
ip []byte,
|
||||||
|
) (*Rule, error) {
|
||||||
|
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||||
|
return nil, fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
nftRule: ruleset.nftRule,
|
||||||
|
nftSet: ruleset.nftSet,
|
||||||
|
ruleID: xid.New().String(),
|
||||||
|
ip: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset.issuedRules[rule.ruleID] = &rule
|
||||||
|
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteRule from ruleset and returns true if contains other rules
|
||||||
|
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||||
|
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleset := r.rulesets[rulesetID]
|
||||||
|
if ruleset.nftRule == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||||
|
delete(ruleset.issuedRules, rule.ruleID)
|
||||||
|
|
||||||
|
if len(ruleset.issuedRules) == 0 {
|
||||||
|
delete(r.rulesets, ruleset.rulesetID)
|
||||||
|
if rule.nftSet != nil {
|
||||||
|
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||||
|
//
|
||||||
|
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||||
|
// we set correct handle value to it.
|
||||||
|
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||||
|
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||||
|
ruleset, ok := r.rulesets[string(split[0])]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("ruleset not found")
|
||||||
|
}
|
||||||
|
*ruleset.nftRule = *nftRule
|
||||||
|
return nil
|
||||||
|
}
|
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{
|
||||||
|
UserData: []byte(rulesetID),
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||||
|
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_addRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||||
|
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||||
|
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||||
|
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||||
|
|
||||||
|
ruleset2 := &nftRuleset{
|
||||||
|
rulesetID: "ruleset-2",
|
||||||
|
}
|
||||||
|
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||||
|
require.Error(t, err, "addRule() should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.1.1")
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
ip2 := []byte("192.168.1.1")
|
||||||
|
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule2, "rule should not be nil")
|
||||||
|
|
||||||
|
hasNext := rulesetManager.deleteRule(rule)
|
||||||
|
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||||
|
|
||||||
|
// Check that the rule is no longer in the manager.
|
||||||
|
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||||
|
|
||||||
|
hasNext = rulesetManager.deleteRule(rule2)
|
||||||
|
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||||
|
// Add a rule to the ruleset.
|
||||||
|
ip := []byte("192.168.0.1")
|
||||||
|
|
||||||
|
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||||
|
require.NoError(t, err, "addRule() failed")
|
||||||
|
require.NotNil(t, rule, "rule should not be nil")
|
||||||
|
|
||||||
|
nftRuleCopy := nftRule
|
||||||
|
nftRuleCopy.Handle = 2
|
||||||
|
nftRuleCopy.UserData = []byte(rulesetID)
|
||||||
|
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||||
|
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||||
|
// check correct work with references
|
||||||
|
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||||
|
// Create a ruleset manager.
|
||||||
|
rulesetManager := newRuleManager()
|
||||||
|
// Create a ruleset.
|
||||||
|
rulesetID := "ruleset-1"
|
||||||
|
nftRule := nftables.Rule{}
|
||||||
|
nftSet := nftables.Set{
|
||||||
|
ID: 2,
|
||||||
|
}
|
||||||
|
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||||
|
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||||
|
|
||||||
|
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||||
|
require.True(t, ok, "getRuleset() failed")
|
||||||
|
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||||
|
|
||||||
|
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||||
|
require.False(t, ok, "getRuleset() failed")
|
||||||
|
}
|
@ -84,6 +84,7 @@ func (m *Manager) AddFiltering(
|
|||||||
dPort *fw.Port,
|
dPort *fw.Port,
|
||||||
direction fw.RuleDirection,
|
direction fw.RuleDirection,
|
||||||
action fw.Action,
|
action fw.Action,
|
||||||
|
ipsetName string,
|
||||||
comment string,
|
comment string,
|
||||||
) (fw.Rule, error) {
|
) (fw.Rule, error) {
|
||||||
r := Rule{
|
r := Rule{
|
||||||
@ -181,6 +182,9 @@ func (m *Manager) Reset() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flush doesn't need to be implemented for this manager
|
||||||
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||||
|
@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
action = fw.ActionDrop
|
action = fw.ActionDrop
|
||||||
comment = "Test rule 2"
|
comment = "Test rule 2"
|
||||||
|
|
||||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -236,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -274,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
comment := "Test rule"
|
||||||
|
|
||||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@ -390,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []int{1000 + i}}
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
} else {
|
} else {
|
||||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
@ -33,9 +33,22 @@ type Manager interface {
|
|||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
manager firewall.Manager
|
manager firewall.Manager
|
||||||
rulesPairs map[string][]firewall.Rule
|
ipsetCounter int
|
||||||
mutex sync.Mutex
|
rulesPairs map[string][]firewall.Rule
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipsetInfo struct {
|
||||||
|
name string
|
||||||
|
ipCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
|
return &DefaultManager{
|
||||||
|
manager: fm,
|
||||||
|
rulesPairs: make(map[string][]firewall.Rule),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := d.manager.Flush(); err != nil {
|
||||||
|
log.Error("failed to flush firewall rules: ", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
enableSSH := (networkMap.PeerConfig != nil &&
|
enableSSH := (networkMap.PeerConfig != nil &&
|
||||||
@ -108,8 +127,32 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
|
|
||||||
applyFailed := false
|
applyFailed := false
|
||||||
newRulePairs := make(map[string][]firewall.Rule)
|
newRulePairs := make(map[string][]firewall.Rule)
|
||||||
|
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||||
|
|
||||||
|
// calculate which IP's can be grouped in by which ipset
|
||||||
|
// to do that we use rule selector (which is just rule properties without IP's)
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
|
selector := d.getRuleGroupingSelector(r)
|
||||||
|
ipset, ok := ipsetByRuleSelectors[selector]
|
||||||
|
if !ok {
|
||||||
|
ipset = &ipsetInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ipset.ipCount++
|
||||||
|
ipsetByRuleSelectors[selector] = ipset
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
|
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||||
|
ipsetName := ""
|
||||||
|
if ipset.name == "" {
|
||||||
|
d.ipsetCounter++
|
||||||
|
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||||
|
}
|
||||||
|
ipsetName = ipset.name
|
||||||
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||||
applyFailed = true
|
applyFailed = true
|
||||||
@ -154,7 +197,10 @@ func (d *DefaultManager) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
|
r *mgmProto.FirewallRule,
|
||||||
|
ipsetName string,
|
||||||
|
) (string, []firewall.Rule, error) {
|
||||||
ip := net.ParseIP(r.PeerIP)
|
ip := net.ParseIP(r.PeerIP)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
@ -190,9 +236,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
|||||||
var err error
|
var err error
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.FirewallRule_IN:
|
case mgmProto.FirewallRule_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||||
case mgmProto.FirewallRule_OUT:
|
case mgmProto.FirewallRule_OUT:
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@ -205,9 +251,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
|||||||
return ruleID, rules, nil
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addInRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@ -217,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@ -225,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return append(rules, rule), nil
|
return append(rules, rule), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
func (d *DefaultManager) addOutRules(
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
comment string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
rule, err := d.manager.AddFiltering(
|
||||||
|
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@ -237,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
|||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
rule, err = d.manager.AddFiltering(
|
||||||
|
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||||
}
|
}
|
||||||
@ -282,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
in := protoMatch{}
|
in := protoMatch{}
|
||||||
out := protoMatch{}
|
out := protoMatch{}
|
||||||
|
|
||||||
|
// trace which type of protocols was squashed
|
||||||
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
|
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||||
// But we zeroed the IP's for protocol if:
|
// But we zeroed the IP's for protocol if:
|
||||||
@ -298,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = map[string]int{}
|
||||||
}
|
}
|
||||||
match := protocols[r.Protocol]
|
|
||||||
|
|
||||||
if _, ok := match[r.PeerIP]; ok {
|
// special case, when we recieve this all network IP address
|
||||||
|
// it means that rules for that protocol was already optimized on the
|
||||||
|
// management side
|
||||||
|
if r.PeerIP == "0.0.0.0" {
|
||||||
|
squashedRules = append(squashedRules, r)
|
||||||
|
squashedProtocols[r.Protocol] = struct{}{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
match[r.PeerIP] = i
|
|
||||||
|
ipset := protocols[r.Protocol]
|
||||||
|
|
||||||
|
if _, ok := ipset[r.PeerIP]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipset[r.PeerIP] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
@ -324,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
mgmProto.FirewallRule_UDP,
|
mgmProto.FirewallRule_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||||
@ -382,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
return append(rules, squashedRules...), squashedProtocols
|
return append(rules, squashedRules...), squashedProtocols
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
|
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||||
|
}
|
||||||
|
|
||||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.FirewallRule_TCP:
|
case mgmProto.FirewallRule_TCP:
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &DefaultManager{
|
return newDefaultManager(fm), nil
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DefaultManager{
|
return newDefaultManager(fm), nil
|
||||||
manager: fm,
|
|
||||||
rulesPairs: make(map[string][]firewall.Rule),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,11 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"fmt"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@ -240,7 +242,15 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
|
|||||||
peersExists := make(map[string]struct{})
|
peersExists := make(map[string]struct{})
|
||||||
rules := make([]*FirewallRule, 0)
|
rules := make([]*FirewallRule, 0)
|
||||||
peers := make([]*Peer, 0)
|
peers := make([]*Peer, 0)
|
||||||
|
|
||||||
|
all, err := a.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get group all: %v", err)
|
||||||
|
all = &Group{}
|
||||||
|
}
|
||||||
|
|
||||||
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
|
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
|
||||||
|
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||||
for _, peer := range groupPeers {
|
for _, peer := range groupPeers {
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
continue
|
continue
|
||||||
@ -250,29 +260,33 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
|
|||||||
peersExists[peer.ID] = struct{}{}
|
peersExists[peer.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
fwRule := FirewallRule{
|
fr := FirewallRule{
|
||||||
PeerIP: peer.IP.String(),
|
PeerIP: peer.IP.String(),
|
||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: string(rule.Action),
|
Action: string(rule.Action),
|
||||||
Protocol: string(rule.Protocol),
|
Protocol: string(rule.Protocol),
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := fmt.Sprintf("%s%d", peer.ID+peer.IP.String(), direction)
|
if isAll {
|
||||||
ruleID += string(rule.Protocol) + string(rule.Action) + strings.Join(rule.Ports, ",")
|
fr.PeerIP = "0.0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
|
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||||
if _, ok := rulesExists[ruleID]; ok {
|
if _, ok := rulesExists[ruleID]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rulesExists[ruleID] = struct{}{}
|
rulesExists[ruleID] = struct{}{}
|
||||||
|
|
||||||
if len(rule.Ports) == 0 {
|
if len(rule.Ports) == 0 {
|
||||||
rules = append(rules, &fwRule)
|
rules = append(rules, &fr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, port := range rule.Ports {
|
for _, port := range rule.Ports {
|
||||||
addRule := fwRule
|
pr := fr // clone rule and add set new port
|
||||||
addRule.Port = port
|
pr.Port = port
|
||||||
rules = append(rules, &addRule)
|
rules = append(rules, &pr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, func() ([]*Peer, []*FirewallRule) {
|
}, func() ([]*Peer, []*FirewallRule) {
|
||||||
|
@ -126,6 +126,20 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
assert.Contains(t, peers, account.Peers["peerF"])
|
assert.Contains(t, peers, account.Peers["peerF"])
|
||||||
|
|
||||||
epectedFirewallRules := []*FirewallRule{
|
epectedFirewallRules := []*FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "0.0.0.0",
|
||||||
|
Direction: firewallRuleDirectionIN,
|
||||||
|
Action: "accept",
|
||||||
|
Protocol: "all",
|
||||||
|
Port: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "0.0.0.0",
|
||||||
|
Direction: firewallRuleDirectionOUT,
|
||||||
|
Action: "accept",
|
||||||
|
Protocol: "all",
|
||||||
|
Port: "",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
PeerIP: "100.65.14.88",
|
PeerIP: "100.65.14.88",
|
||||||
Direction: firewallRuleDirectionIN,
|
Direction: firewallRuleDirectionIN,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user