From e69ec6ab6a477034c34b945d58c03ee0a9a14091 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Tue, 18 Jul 2023 13:12:50 +0400 Subject: [PATCH] Optimize ACL performance (#994) * Optimize rules with All groups * Use IP sets in ACLs (nftables implementation) * Fix squash rule when we receive optimized rules list from management --- client/firewall/firewall.go | 4 + client/firewall/iptables/manager_linux.go | 4 + .../firewall/iptables/manager_linux_test.go | 10 +- client/firewall/nftables/manager_linux.go | 363 +++++++++++++++--- .../firewall/nftables/manager_linux_test.go | 23 +- client/firewall/nftables/rule_linux.go | 9 +- client/firewall/nftables/ruleset_linux.go | 115 ++++++ .../firewall/nftables/ruleset_linux_test.go | 122 ++++++ client/firewall/uspfilter/uspfilter.go | 4 + client/firewall/uspfilter/uspfilter_test.go | 14 +- client/internal/acl/manager.go | 118 +++++- client/internal/acl/manager_create.go | 6 +- client/internal/acl/manager_create_linux.go | 5 +- management/server/policy.go | 30 +- management/server/policy_test.go | 14 + 15 files changed, 727 insertions(+), 114 deletions(-) create mode 100644 client/firewall/nftables/ruleset_linux.go create mode 100644 client/firewall/nftables/ruleset_linux_test.go diff --git a/client/firewall/firewall.go b/client/firewall/firewall.go index f91adb7c1..5d003e2f0 100644 --- a/client/firewall/firewall.go +++ b/client/firewall/firewall.go @@ -51,6 +51,7 @@ type Manager interface { dPort *Port, direction RuleDirection, action Action, + ipsetName string, comment string, ) (Rule, error) @@ -60,5 +61,8 @@ type Manager interface { // Reset firewall to the default state Reset() error + // Flush the changes to firewall controller + Flush() error + // TODO: migrate routemanager firewal actions to this interface } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 6ddab0b8f..a4a7a6c3b 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -92,6 +92,7 @@ func (m *Manager) AddFiltering( dPort *fw.Port, direction fw.RuleDirection, action fw.Action, + ipsetName string, comment string, ) (fw.Rule, error) { m.mutex.Lock() @@ -202,6 +203,9 @@ func (m *Manager) Reset() error { return nil } +// Flush doesn't need to be implemented for this manager +func (m *Manager) Flush() error { return nil } + // reset firewall chain, clear it and drop it func (m *Manager) reset(client *iptables.IPTables, table string) error { ok, err := client.ChainExists(table, ChainInputFilterName) diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 7c78a4ee2..cbf9d4c76 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -68,7 +68,7 @@ func TestIptablesManager(t *testing.T) { t.Run("add first rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") + rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) @@ -81,7 +81,7 @@ func TestIptablesManager(t *testing.T) { Values: []int{8043: 8046}, } rule2, err = manager.AddFiltering( - ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range") + ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...) @@ -107,7 +107,7 @@ func TestIptablesManager(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") port := &fw.Port{Values: []int{5353}} - _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic") + _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset() @@ -167,9 +167,9 @@ func TestIptablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") + _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") + _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index e94f93f9e..71085276d 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -6,12 +6,14 @@ import ( "fmt" "net" "net/netip" + "strconv" "strings" "sync" + "time" "github.com/google/nftables" "github.com/google/nftables/expr" - "github.com/google/uuid" + log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall" @@ -29,11 +31,14 @@ const ( FilterOutputChainName = "netbird-acl-output-filter" ) +var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex - conn *nftables.Conn + rConn *nftables.Conn + sConn *nftables.Conn tableIPv4 *nftables.Table tableIPv6 *nftables.Table @@ -43,6 +48,10 @@ type Manager struct { filterInputChainIPv6 *nftables.Chain filterOutputChainIPv6 *nftables.Chain + rulesetManager *rulesetManager + setRemovedIPs map[string]struct{} + setRemoved map[string]*nftables.Set + wgIface iFaceMapper } @@ -54,8 +63,23 @@ type iFaceMapper interface { // Create nftables firewall manager func Create(wgIface iFaceMapper) (*Manager, error) { + // sConn is used for creating sets and adding/removing elements from them + // it's differ then rConn (which does create new conn for each flush operation) + // and is permanent. Using same connection for booth type of operations + // overloads netlink with high amount of rules ( > 10000) + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + m := &Manager{ - conn: &nftables.Conn{}, + rConn: &nftables.Conn{}, + sConn: sConn, + + rulesetManager: newRuleManager(), + setRemovedIPs: map[string]struct{}{}, + setRemoved: map[string]*nftables.Set{}, + wgIface: wgIface, } @@ -77,6 +101,7 @@ func (m *Manager) AddFiltering( dPort *fw.Port, direction fw.RuleDirection, action fw.Action, + ipsetName string, comment string, ) (fw.Rule, error) { m.mutex.Lock() @@ -84,6 +109,7 @@ func (m *Manager) AddFiltering( var ( err error + ipset *nftables.Set table *nftables.Table chain *nftables.Chain ) @@ -107,6 +133,46 @@ func (m *Manager) AddFiltering( return nil, err } + rawIP := ip.To4() + if rawIP == nil { + rawIP = ip.To16() + } + + rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName) + + if ipsetName != "" { + // if we already have set with given name, just add ip to the set + // and return rule with new ID in other case let's create rule + // with fresh created set and set element + + var isSetNew bool + ipset, err := m.rConn.GetSetByName(table, ipsetName) + if err != nil { + if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil { + return nil, fmt.Errorf("get set name: %v", err) + } + isSetNew = true + } + + if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil { + return nil, fmt.Errorf("add set element for the first time: %v", err) + } + if err := m.sConn.Flush(); err != nil { + return nil, fmt.Errorf("flush add elements: %v", err) + } + + if !isSetNew { + // if we already have nftables rules with set for given direction + // just add new rule to the ruleset and return new fw.Rule object + + if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok { + return m.rulesetManager.addRule(ruleset, rawIP) + } + // if ipset exists but it is not linked to rule for given direction + // create new rule for direction and bind ipset to it later + } + } + ifaceKey := expr.MetaKeyIIFNAME if direction == fw.RuleDirectionOUT { ifaceKey = expr.MetaKeyOIFNAME @@ -146,39 +212,47 @@ func (m *Manager) AddFiltering( }) } - // don't use IP matching if IP is ip 0.0.0.0 - if s := ip.String(); s != "0.0.0.0" && s != "::" { + // check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value + // in that case not add IP match expression into the rule definition + if !bytes.HasPrefix(anyIP, rawIP) { // source address position - var adrLen, adrOffset uint32 - if ip.To4() == nil { - adrLen = 16 - adrOffset = 8 - } else { - adrLen = 4 - adrOffset = 12 + addrLen := uint32(len(rawIP)) + addrOffset := uint32(12) + if addrLen == 16 { + addrOffset = 8 } // change to destination address position if need if direction == fw.RuleDirectionOUT { - adrOffset += adrLen + addrOffset += addrLen } - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: adrOffset, - Len: adrLen, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: add.AsSlice(), + Offset: addrOffset, + Len: addrLen, }, ) + // add individual IP for match if no ipset defined + if ipset == nil { + expressions = append(expressions, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: rawIP, + }, + ) + } else { + expressions = append(expressions, + &expr.Lookup{ + SourceRegister: 1, + SetName: ipsetName, + SetID: ipset.ID, + }, + ) + } } if sPort != nil && len(sPort.Values) != 0 { @@ -219,39 +293,76 @@ func (m *Manager) AddFiltering( expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) } - id := uuid.New().String() - userData := []byte(strings.Join([]string{id, comment}, " ")) + userData := []byte(strings.Join([]string{rulesetID, comment}, " ")) - _ = m.conn.InsertRule(&nftables.Rule{ + rule := m.rConn.InsertRule(&nftables.Rule{ Table: table, Chain: chain, Position: 0, Exprs: expressions, UserData: userData, }) - - if err := m.conn.Flush(); err != nil { - return nil, err + if err := m.rConn.Flush(); err != nil { + return nil, fmt.Errorf("flush insert rule: %v", err) } - list, err := m.conn.GetRules(table, chain) - if err != nil { - return nil, err + ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset) + return m.rulesetManager.addRule(ruleset, rawIP) +} + +// getRulesetID returns ruleset ID based on given parameters +func (m *Manager) getRulesetID( + ip net.IP, + proto fw.Protocol, + sPort *fw.Port, + dPort *fw.Port, + direction fw.RuleDirection, + action fw.Action, + ipsetName string, +) string { + rulesetID := ":" + strconv.Itoa(int(direction)) + ":" + if sPort != nil { + rulesetID += sPort.String() + } + rulesetID += ":" + if dPort != nil { + rulesetID += dPort.String() + } + rulesetID += ":" + rulesetID += strconv.Itoa(int(action)) + if ipsetName == "" { + return "ip:" + ip.String() + rulesetID + } + return "set:" + ipsetName + rulesetID +} + +// createSet in given table by name +func (m *Manager) createSet( + table *nftables.Table, + rawIP []byte, + name string, +) (*nftables.Set, error) { + keyType := nftables.TypeIPAddr + if len(rawIP) == 16 { + keyType = nftables.TypeIP6Addr + } + // else we create new ipset and continue creating rule + ipset := &nftables.Set{ + Name: name, + Table: table, + Dynamic: true, + KeyType: keyType, } - // Add the rule to the chain - rule := &Rule{id: id} - for _, r := range list { - if bytes.Equal(r.UserData, userData) { - rule.Rule = r - break - } - } - if rule.Rule == nil { - return nil, fmt.Errorf("rule not found") + if err := m.rConn.AddSet(ipset, nil); err != nil { + return nil, fmt.Errorf("create set: %v", err) } - return rule, nil + if err := m.rConn.Flush(); err != nil { + return nil, fmt.Errorf("flush created set: %v", err) + } + + return ipset, nil } // chain returns the chain for the given IP address with specific settings @@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { } func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { - tables, err := m.conn.ListTablesOfFamily(family) + tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } @@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables } } - return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil + table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}) + if err := m.rConn.Flush(); err != nil { + return nil, err + } + return table, nil } func (m *Manager) createChainIfNotExists( @@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists( return nil, err } - chains, err := m.conn.ListChainsOfTableFamily(family) + chains, err := m.rConn.ListChainsOfTableFamily(family) if err != nil { return nil, fmt.Errorf("list of chains: %w", err) } @@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists( Policy: &polAccept, } - chain = m.conn.AddChain(chain) + chain = m.rConn.AddChain(chain) ifaceKey := expr.MetaKeyIIFNAME shiftDSTAddr := 0 @@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists( ) } - _ = m.conn.AddRule(&nftables.Rule{ + _ = m.rConn.AddRule(&nftables.Rule{ Table: table, Chain: chain, Exprs: expressions, @@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists( }, &expr.Verdict{Kind: expr.VerdictDrop}, } - _ = m.conn.AddRule(&nftables.Rule{ + _ = m.rConn.AddRule(&nftables.Rule{ Table: table, Chain: chain, Exprs: expressions, }) - if err := m.conn.Flush(); err != nil { + + if err := m.rConn.Flush(); err != nil { return nil, err } @@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists( // DeleteRule from the firewall by rule definition func (m *Manager) DeleteRule(rule fw.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + nativeRule, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") } - if err := m.conn.DelRule(nativeRule.Rule); err != nil { - return err + if nativeRule.nftRule == nil { + return nil } - return m.conn.Flush() + if nativeRule.nftSet != nil { + // call twice of delete set element raises error + // so we need to check if element is already removed + key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip) + if _, ok := m.setRemovedIPs[key]; !ok { + err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}}) + if err != nil { + log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err) + } + if err := m.sConn.Flush(); err != nil { + return err + } + m.setRemovedIPs[key] = struct{}{} + } + } + + if m.rulesetManager.deleteRule(nativeRule) { + // deleteRule indicates that we still have IP in the ruleset + // it means we should not remove the nftables rule but need to update set + // so we prepare IP to be removed from set on the next flush call + return nil + } + + // ruleset doesn't contain IP anymore (or contains only one), remove nft rule + if err := m.rConn.DelRule(nativeRule.nftRule); err != nil { + log.Errorf("failed to delete rule: %v", err) + } + if err := m.rConn.Flush(); err != nil { + return err + } + nativeRule.nftRule = nil + + if nativeRule.nftSet != nil { + if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok { + m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet + } + nativeRule.nftSet = nil + } + + return nil } // Reset firewall to the default state @@ -475,27 +633,116 @@ func (m *Manager) Reset() error { m.mutex.Lock() defer m.mutex.Unlock() - chains, err := m.conn.ListChains() + chains, err := m.rConn.ListChains() if err != nil { return fmt.Errorf("list of chains: %w", err) } for _, c := range chains { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { - m.conn.DelChain(c) + m.rConn.DelChain(c) } } - tables, err := m.conn.ListTables() + tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list of tables: %w", err) } for _, t := range tables { if t.Name == FilterTableName { - m.conn.DelTable(t) + m.rConn.DelTable(t) } } - return m.conn.Flush() + return m.rConn.Flush() +} + +// Flush rule/chain/set operations from the buffer +// +// Method also get all rules after flush and refreshes handle values in the rulesets +func (m *Manager) Flush() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if err := m.flushWithBackoff(); err != nil { + return err + } + + // set must be removed after flush rule changes + // otherwise we will get error + for _, s := range m.setRemoved { + m.rConn.FlushSet(s) + m.rConn.DelSet(s) + } + + if len(m.setRemoved) > 0 { + if err := m.flushWithBackoff(); err != nil { + return err + } + } + + m.setRemovedIPs = map[string]struct{}{} + m.setRemoved = map[string]*nftables.Set{} + + if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil { + log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) + } + + if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil { + log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) + } + + if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil { + log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err) + } + + if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil { + log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err) + } + + return nil +} + +func (m *Manager) flushWithBackoff() (err error) { + backoff := 4 + backoffTime := 1000 * time.Millisecond + for i := 0; ; i++ { + err = m.rConn.Flush() + if err != nil { + if !strings.Contains(err.Error(), "busy") { + return + } + log.Error("failed to flush nftables, retrying...") + if i == backoff-1 { + return err + } + time.Sleep(backoffTime) + backoffTime = backoffTime * 2 + continue + } + break + } + return +} + +func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error { + if table == nil || chain == nil { + return nil + } + + list, err := m.rConn.GetRules(table, chain) + if err != nil { + return err + } + + for _, rule := range list { + if len(rule.UserData) != 0 { + if err := m.rulesetManager.setNftRuleHandle(rule); err != nil { + log.Errorf("failed to set rule handle: %v", err) + } + } + } + + return nil } func encodePort(port fw.Port) []byte { diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 9c0d247c5..164d5d0dc 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) { // just check on the local interface manager, err := Create(mock) require.NoError(t, err) - time.Sleep(time.Second) + time.Sleep(time.Second * 3) defer func() { err = manager.Reset() @@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) { fw.RuleDirectionIN, fw.ActionDrop, "", + "", ) require.NoError(t, err, "failed to add rule") + err = manager.Flush() + require.NoError(t, err, "failed to flush") + rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) require.NoError(t, err, "failed to get rules") + // test expectations: // 1) regular rule // 2) "accept extra routed traffic rule" for the interface @@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) { err = manager.DeleteRule(rule) require.NoError(t, err, "failed to delete rule") + err = manager.Flush() + require.NoError(t, err, "failed to flush") + rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) require.NoError(t, err, "failed to get rules") // test expectations: @@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { // just check on the local interface manager, err := Create(mock) require.NoError(t, err) - time.Sleep(time.Second) + time.Sleep(time.Second * 3) defer func() { if err := manager.Reset(); err != nil { @@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") + _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") + _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } - require.NoError(t, err, "failed to add rule") + + if i%100 == 0 { + err = manager.Flush() + require.NoError(t, err, "failed to flush") + } } + t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax)) }) } diff --git a/client/firewall/nftables/rule_linux.go b/client/firewall/nftables/rule_linux.go index 7fe0dcb5e..98d1147cd 100644 --- a/client/firewall/nftables/rule_linux.go +++ b/client/firewall/nftables/rule_linux.go @@ -6,11 +6,14 @@ import ( // Rule to handle management of rules type Rule struct { - *nftables.Rule - id string + nftRule *nftables.Rule + nftSet *nftables.Set + + ruleID string + ip []byte } // GetRuleID returns the rule id func (r *Rule) GetRuleID() string { - return r.id + return r.ruleID } diff --git a/client/firewall/nftables/ruleset_linux.go b/client/firewall/nftables/ruleset_linux.go new file mode 100644 index 000000000..536a5ee18 --- /dev/null +++ b/client/firewall/nftables/ruleset_linux.go @@ -0,0 +1,115 @@ +package nftables + +import ( + "bytes" + "fmt" + + "github.com/google/nftables" + "github.com/rs/xid" +) + +// nftRuleset links native firewall rule and ipset to ACL generated rules +type nftRuleset struct { + nftRule *nftables.Rule + nftSet *nftables.Set + issuedRules map[string]*Rule + rulesetID string +} + +type rulesetManager struct { + rulesets map[string]*nftRuleset + + nftSetName2rulesetID map[string]string + issuedRuleID2rulesetID map[string]string +} + +func newRuleManager() *rulesetManager { + return &rulesetManager{ + rulesets: map[string]*nftRuleset{}, + + nftSetName2rulesetID: map[string]string{}, + issuedRuleID2rulesetID: map[string]string{}, + } +} + +func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) { + ruleset, ok := r.rulesets[rulesetID] + return ruleset, ok +} + +func (r *rulesetManager) createRuleset( + rulesetID string, + nftRule *nftables.Rule, + nftSet *nftables.Set, +) *nftRuleset { + ruleset := nftRuleset{ + rulesetID: rulesetID, + nftRule: nftRule, + nftSet: nftSet, + issuedRules: map[string]*Rule{}, + } + r.rulesets[ruleset.rulesetID] = &ruleset + if nftSet != nil { + r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID + } + return &ruleset +} + +func (r *rulesetManager) addRule( + ruleset *nftRuleset, + ip []byte, +) (*Rule, error) { + if _, ok := r.rulesets[ruleset.rulesetID]; !ok { + return nil, fmt.Errorf("ruleset not found") + } + + rule := Rule{ + nftRule: ruleset.nftRule, + nftSet: ruleset.nftSet, + ruleID: xid.New().String(), + ip: ip, + } + + ruleset.issuedRules[rule.ruleID] = &rule + r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID + + return &rule, nil +} + +// deleteRule from ruleset and returns true if contains other rules +func (r *rulesetManager) deleteRule(rule *Rule) bool { + rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID] + if !ok { + return false + } + + ruleset := r.rulesets[rulesetID] + if ruleset.nftRule == nil { + return false + } + delete(r.issuedRuleID2rulesetID, rule.ruleID) + delete(ruleset.issuedRules, rule.ruleID) + + if len(ruleset.issuedRules) == 0 { + delete(r.rulesets, ruleset.rulesetID) + if rule.nftSet != nil { + delete(r.nftSetName2rulesetID, rule.nftSet.Name) + } + return false + } + return true +} + +// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number +// +// This is important to do, because after we add rule to the nftables we can't update it until +// we set correct handle value to it. +func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error { + split := bytes.Split(nftRule.UserData, []byte(" ")) + ruleset, ok := r.rulesets[string(split[0])] + if !ok { + return fmt.Errorf("ruleset not found") + } + *ruleset.nftRule = *nftRule + return nil +} diff --git a/client/firewall/nftables/ruleset_linux_test.go b/client/firewall/nftables/ruleset_linux_test.go new file mode 100644 index 000000000..74b37d8f8 --- /dev/null +++ b/client/firewall/nftables/ruleset_linux_test.go @@ -0,0 +1,122 @@ +package nftables + +import ( + "testing" + + "github.com/google/nftables" + "github.com/stretchr/testify/require" +) + +func TestRulesetManager_createRuleset(t *testing.T) { + // Create a ruleset manager. + rulesetManager := newRuleManager() + + // Create a ruleset. + rulesetID := "ruleset-1" + nftRule := nftables.Rule{ + UserData: []byte(rulesetID), + } + ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) + require.NotNil(t, ruleset, "createRuleset() failed") + require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect") + require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect") +} + +func TestRulesetManager_addRule(t *testing.T) { + // Create a ruleset manager. + rulesetManager := newRuleManager() + + // Create a ruleset. + rulesetID := "ruleset-1" + nftRule := nftables.Rule{} + ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) + + // Add a rule to the ruleset. + ip := []byte("192.168.1.1") + rule, err := rulesetManager.addRule(ruleset, ip) + require.NoError(t, err, "addRule() failed") + require.NotNil(t, rule, "rule should not be nil") + require.NotEqual(t, rule.ruleID, "ruleID is empty") + require.EqualValues(t, rule.ip, ip, "ip is incorrect") + require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset") + require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager") + + ruleset2 := &nftRuleset{ + rulesetID: "ruleset-2", + } + _, err = rulesetManager.addRule(ruleset2, ip) + require.Error(t, err, "addRule() should have failed") +} + +func TestRulesetManager_deleteRule(t *testing.T) { + // Create a ruleset manager. + rulesetManager := newRuleManager() + + // Create a ruleset. + rulesetID := "ruleset-1" + nftRule := nftables.Rule{} + ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) + + // Add a rule to the ruleset. + ip := []byte("192.168.1.1") + rule, err := rulesetManager.addRule(ruleset, ip) + require.NoError(t, err, "addRule() failed") + require.NotNil(t, rule, "rule should not be nil") + + ip2 := []byte("192.168.1.1") + rule2, err := rulesetManager.addRule(ruleset, ip2) + require.NoError(t, err, "addRule() failed") + require.NotNil(t, rule2, "rule should not be nil") + + hasNext := rulesetManager.deleteRule(rule) + require.True(t, hasNext, "deleteRule() should have returned true") + + // Check that the rule is no longer in the manager. + require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted") + + hasNext = rulesetManager.deleteRule(rule2) + require.False(t, hasNext, "deleteRule() should have returned false") +} + +func TestRulesetManager_setNftRuleHandle(t *testing.T) { + // Create a ruleset manager. + rulesetManager := newRuleManager() + // Create a ruleset. + rulesetID := "ruleset-1" + nftRule := nftables.Rule{} + ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) + // Add a rule to the ruleset. + ip := []byte("192.168.0.1") + + rule, err := rulesetManager.addRule(ruleset, ip) + require.NoError(t, err, "addRule() failed") + require.NotNil(t, rule, "rule should not be nil") + + nftRuleCopy := nftRule + nftRuleCopy.Handle = 2 + nftRuleCopy.UserData = []byte(rulesetID) + err = rulesetManager.setNftRuleHandle(&nftRuleCopy) + require.NoError(t, err, "setNftRuleHandle() failed") + // check correct work with references + require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect") +} + +func TestRulesetManager_getRuleset(t *testing.T) { + // Create a ruleset manager. + rulesetManager := newRuleManager() + // Create a ruleset. + rulesetID := "ruleset-1" + nftRule := nftables.Rule{} + nftSet := nftables.Set{ + ID: 2, + } + ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet) + require.NotNil(t, ruleset, "createRuleset() failed") + + find, ok := rulesetManager.getRuleset(rulesetID) + require.True(t, ok, "getRuleset() failed") + require.Equal(t, ruleset, find, "getRulesetBySetID() failed") + + _, ok = rulesetManager.getRuleset("does-not-exist") + require.False(t, ok, "getRuleset() failed") +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 5cc215256..3dead1db4 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -84,6 +84,7 @@ func (m *Manager) AddFiltering( dPort *fw.Port, direction fw.RuleDirection, action fw.Action, + ipsetName string, comment string, ) (fw.Rule, error) { r := Rule{ @@ -181,6 +182,9 @@ func (m *Manager) Reset() error { return nil } +// Flush doesn't need to be implemented for this manager +func (m *Manager) Flush() error { return nil } + // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte) bool { return m.dropFilter(packetData, m.outgoingRules, false) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index c7f38a44f..bc94f59c1 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) + rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) + rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) { action = fw.ActionDrop comment = "Test rule 2" - rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment) + rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -236,7 +236,7 @@ func TestManagerReset(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment) + _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -274,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) { action := fw.ActionAccept comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment) + _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -390,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic") + _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic") + _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 95f2c253e..d4b6930a0 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -33,9 +33,22 @@ type Manager interface { // DefaultManager uses firewall manager to handle type DefaultManager struct { - manager firewall.Manager - rulesPairs map[string][]firewall.Rule - mutex sync.Mutex + manager firewall.Manager + ipsetCounter int + rulesPairs map[string][]firewall.Rule + mutex sync.Mutex +} + +type ipsetInfo struct { + name string + ipCount int +} + +func newDefaultManager(fm firewall.Manager) *DefaultManager { + return &DefaultManager{ + manager: fm, + rulesPairs: make(map[string][]firewall.Rule), + } } // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. @@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { return } + defer func() { + if err := d.manager.Flush(); err != nil { + log.Error("failed to flush firewall rules: ", err) + } + }() + rules, squashedProtocols := d.squashAcceptRules(networkMap) enableSSH := (networkMap.PeerConfig != nil && @@ -108,8 +127,32 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { applyFailed := false newRulePairs := make(map[string][]firewall.Rule) + ipsetByRuleSelectors := make(map[string]*ipsetInfo) + + // calculate which IP's can be grouped in by which ipset + // to do that we use rule selector (which is just rule properties without IP's) for _, r := range rules { - pairID, rulePair, err := d.protoRuleToFirewallRule(r) + selector := d.getRuleGroupingSelector(r) + ipset, ok := ipsetByRuleSelectors[selector] + if !ok { + ipset = &ipsetInfo{} + } + + ipset.ipCount++ + ipsetByRuleSelectors[selector] = ipset + } + + for _, r := range rules { + // if this rule is member of rule selection with more than DefaultIPsCountForSet + // it's IP address can be used in the ipset for firewall manager which supports it + ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)] + ipsetName := "" + if ipset.name == "" { + d.ipsetCounter++ + ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter) + } + ipsetName = ipset.name + pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) if err != nil { log.Errorf("failed to apply firewall rule: %+v, %v", r, err) applyFailed = true @@ -154,7 +197,10 @@ func (d *DefaultManager) Stop() { } } -func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) { +func (d *DefaultManager) protoRuleToFirewallRule( + r *mgmProto.FirewallRule, + ipsetName string, +) (string, []firewall.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") @@ -190,9 +236,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri var err error switch r.Direction { case mgmProto.FirewallRule_IN: - rules, err = d.addInRules(ip, protocol, port, action, "") + rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") case mgmProto.FirewallRule_OUT: - rules, err = d.addOutRules(ip, protocol, port, action, "") + rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") default: return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") } @@ -205,9 +251,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri return ruleID, rules, nil } -func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { +func (d *DefaultManager) addInRules( + ip net.IP, + protocol firewall.Protocol, + port *firewall.Port, + action firewall.Action, + ipsetName string, + comment string, +) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment) + rule, err := d.manager.AddFiltering( + ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } @@ -217,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port return rules, nil } - rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment) + rule, err = d.manager.AddFiltering( + ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } @@ -225,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port return append(rules, rule), nil } -func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { +func (d *DefaultManager) addOutRules( + ip net.IP, + protocol firewall.Protocol, + port *firewall.Port, + action firewall.Action, + ipsetName string, + comment string, +) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment) + rule, err := d.manager.AddFiltering( + ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } @@ -237,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port return rules, nil } - rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment) + rule, err = d.manager.AddFiltering( + ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } @@ -282,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules( in := protoMatch{} out := protoMatch{} + // trace which type of protocols was squashed + squashedRules := []*mgmProto.FirewallRule{} + squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{} + // this function we use to do calculation, can we squash the rules by protocol or not. // We summ amount of Peers IP for given protocol we found in original rules list. // But we zeroed the IP's for protocol if: @@ -298,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules( if _, ok := protocols[r.Protocol]; !ok { protocols[r.Protocol] = map[string]int{} } - match := protocols[r.Protocol] - if _, ok := match[r.PeerIP]; ok { + // special case, when we recieve this all network IP address + // it means that rules for that protocol was already optimized on the + // management side + if r.PeerIP == "0.0.0.0" { + squashedRules = append(squashedRules, r) + squashedProtocols[r.Protocol] = struct{}{} return } - match[r.PeerIP] = i + + ipset := protocols[r.Protocol] + + if _, ok := ipset[r.PeerIP]; ok { + return + } + ipset[r.PeerIP] = i } for i, r := range networkMap.FirewallRules { @@ -324,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules( mgmProto.FirewallRule_UDP, } - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{} squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { for _, protocol := range protocolOrders { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { @@ -382,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules( return append(rules, squashedRules...), squashedProtocols } +// getRuleGroupingSelector takes all rule properties except IP address to build selector +func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { + return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) +} + func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol { switch protocol { case mgmProto.FirewallRule_TCP: diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go index 4987ec587..7d9e6b430 100644 --- a/client/internal/acl/manager_create.go +++ b/client/internal/acl/manager_create.go @@ -6,7 +6,6 @@ import ( "fmt" "runtime" - "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/uspfilter" ) @@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) { if err != nil { return nil, err } - return &DefaultManager{ - manager: fm, - rulesPairs: make(map[string][]firewall.Rule), - }, nil + return newDefaultManager(fm), nil } return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/internal/acl/manager_create_linux.go b/client/internal/acl/manager_create_linux.go index 168114fc2..de4e8adb9 100644 --- a/client/internal/acl/manager_create_linux.go +++ b/client/internal/acl/manager_create_linux.go @@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) { } } - return &DefaultManager{ - manager: fm, - rulesPairs: make(map[string][]firewall.Rule), - }, nil + return newDefaultManager(fm), nil } diff --git a/management/server/policy.go b/management/server/policy.go index eb26c8202..54158eeac 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -2,9 +2,11 @@ package server import ( _ "embed" - "fmt" + "strconv" "strings" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" @@ -240,7 +242,15 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun peersExists := make(map[string]struct{}) rules := make([]*FirewallRule, 0) peers := make([]*Peer, 0) + + all, err := a.GetGroupAll() + if err != nil { + log.Errorf("failed to get group all: %v", err) + all = &Group{} + } + return func(rule *PolicyRule, groupPeers []*Peer, direction int) { + isAll := (len(all.Peers) - 1) == len(groupPeers) for _, peer := range groupPeers { if peer == nil { continue @@ -250,29 +260,33 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun peersExists[peer.ID] = struct{}{} } - fwRule := FirewallRule{ + fr := FirewallRule{ PeerIP: peer.IP.String(), Direction: direction, Action: string(rule.Action), Protocol: string(rule.Protocol), } - ruleID := fmt.Sprintf("%s%d", peer.ID+peer.IP.String(), direction) - ruleID += string(rule.Protocol) + string(rule.Action) + strings.Join(rule.Ports, ",") + if isAll { + fr.PeerIP = "0.0.0.0" + } + + ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")) if _, ok := rulesExists[ruleID]; ok { continue } rulesExists[ruleID] = struct{}{} if len(rule.Ports) == 0 { - rules = append(rules, &fwRule) + rules = append(rules, &fr) continue } for _, port := range rule.Ports { - addRule := fwRule - addRule.Port = port - rules = append(rules, &addRule) + pr := fr // clone rule and add set new port + pr.Port = port + rules = append(rules, &pr) } } }, func() ([]*Peer, []*FirewallRule) { diff --git a/management/server/policy_test.go b/management/server/policy_test.go index d154e54f1..bf003ffac 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -126,6 +126,20 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerF"]) epectedFirewallRules := []*FirewallRule{ + { + PeerIP: "0.0.0.0", + Direction: firewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + }, + { + PeerIP: "0.0.0.0", + Direction: firewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + }, { PeerIP: "100.65.14.88", Direction: firewallRuleDirectionIN,