Add default firewall rule to allow netbird traffic (#1056)

Add a default firewall rule to allow netbird traffic to be handled 
by the access control managers.

Userspace manager behavior:
- When running on Windows, a default rule is add on Windows firewall
- For Linux, we are using one of the Kernel managers to add a single rule
- This PR doesn't handle macOS

Kernel manager behavior:
- For NFtables, if there is a filter table, an INPUT rule is added
- Iptables follows the previous flow if running on kernel mode. If running 
on userspace mode, it adds a single rule for INPUT and OUTPUT chains

A new checkerFW package has been introduced to consolidate checks across
route and access control managers.
It supports a new environment variable to skip nftables and allow iptables tests
This commit is contained in:
Givi Khojanashvili 2023-09-05 23:07:32 +04:00 committed by GitHub
parent e4bc76c4de
commit 246abda46d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 568 additions and 153 deletions

1
.gitignore vendored
View File

@ -19,3 +19,4 @@ client/.distfiles/
infrastructure_files/setup.env infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store

View File

@ -40,6 +40,9 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set

View File

@ -44,6 +44,7 @@ type Manager struct {
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() iface.WGAddress
IsUserspaceBind() bool
} }
type ruleset struct { type ruleset struct {
@ -52,7 +53,7 @@ type ruleset struct {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
m := &Manager{ m := &Manager{
wgIface: wgIface, wgIface: wgIface,
inputDefaultRuleSpecs: []string{ inputDefaultRuleSpecs: []string{
@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
rulesets: make(map[string]ruleset), rulesets: make(map[string]ruleset),
} }
if err := ipset.Init(); err != nil { err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err) return nil, fmt.Errorf("init ipset: %w", err)
} }
// init clients for booth ipv4 and ipv6 // init clients for booth ipv4 and ipv6
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
} }
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) if m.ipv4Client == nil && m.ipv6Client == nil {
if err != nil { return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
} }
if err := m.Reset(); err != nil { if err := m.Reset(); err != nil {
@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment is empty rule ID is used as comment // If comment is empty rule ID is used as comment
@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
return nil return nil
} }
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
return err
}
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err) return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
} }
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err) return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
} }

View File

@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set") panic("AddressFunc is not set")
} }
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) { func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, true)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@ -29,6 +29,8 @@ const (
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter" FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@ -379,7 +381,7 @@ func (m *Manager) chain(
if c != nil { if c != nil {
return c, nil return c, nil
} }
return m.createChainIfNotExists(tf, name, hook, priority, cType) return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
} }
if ip.To4() != nil { if ip.To4() != nil {
@ -399,13 +401,20 @@ func (m *Manager) chain(
} }
// table returns the table for the given family of the IP address // table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 { if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil { if m.tableIPv4 != nil {
return m.tableIPv4, nil return m.tableIPv4, nil
} }
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4) table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil return m.tableIPv6, nil
} }
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6) table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil return m.tableIPv6, nil
} }
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family) tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == FilterTableName { if t.Name == tableName {
return t, nil return t, nil
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return nil, err return nil, err
} }
@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
func (m *Manager) createChainIfNotExists( func (m *Manager) createChainIfNotExists(
family nftables.TableFamily, family nftables.TableFamily,
tableName string,
name string, name string,
hooknum nftables.ChainHook, hooknum nftables.ChainHook,
priority nftables.ChainPriority, priority nftables.ChainPriority,
chainType nftables.ChainType, chainType nftables.ChainType,
) (*nftables.Chain, error) { ) (*nftables.Chain, error) {
table, err := m.table(family) table, err := m.table(family, tableName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c) m.rConn.DelChain(c)
} }
@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
return nil return nil
} }
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) { func (m *Manager) flushWithBackoff() (err error) {
backoff := 4 backoff := 4
backoffTime := 1000 * time.Millisecond backoffTime := 1000 * time.Millisecond
@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
return nil return nil
} }
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func encodePort(port fw.Port) []byte { func encodePort(port fw.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))

View File

@ -0,0 +1,19 @@
//go:build !windows && !linux
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}

View File

@ -0,0 +1,21 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@ -0,0 +1,91 @@
package uspfilter
import (
"errors"
"fmt"
"os/exec"
"strings"
"syscall"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
noRulesMatchCriteria = "No rules match the specified criteria"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, args ...string) error {
active, err := isFirewallRuleActive(ruleName)
if err != nil {
return err
}
if (action == addRule && !active) || (action == deleteRule && active) {
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
args := append(baseArgs, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
return nil
}
func isFirewallRuleActive(ruleName string) (bool, error) {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
output, err := cmd.Output()
if err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// if the firewall rule is not active, we expect last exit code to be 1
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
if exitStatus == 1 {
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
}
}
return false, err
}
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
return true, nil
}

View File

@ -19,6 +19,7 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
} }
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
@ -30,6 +31,8 @@ type Manager struct {
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
mutex sync.RWMutex mutex sync.RWMutex
} }
@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
}, },
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet),
wgIface: iface,
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return nil return nil
} }
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@ -16,6 +16,7 @@ import (
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },

View File

@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
) )
@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil return newDefaultManager(fm), nil
} }
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)

View File

@ -7,27 +7,69 @@ import (
"github.com/netbirdio/netbird/client/firewall/iptables" "github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
// Create creates a firewall manager instance for the Linux // Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (manager *DefaultManager, err error) { func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
// use userspace packet filtering firewall // use userspace packet filtering firewall
if fm, err = uspfilter.Create(iface); err != nil { usfm, err := uspfilter.Create(iface)
if err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err) log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err return nil, err
} }
} else {
if fm, err = nftables.Create(iface); err != nil { // set kernel space firewall Reset as hook for userspace firewall
log.Debugf("failed to create nftables manager: %s", err) // manager Reset method, to clean up
// fallback to iptables if resetHookForUserspace != nil {
if fm, err = iptables.Create(iface); err != nil { usfm.SetResetHook(resetHookForUserspace)
log.Errorf("failed to create iptables manager: %s", err) }
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err return nil, err
} }
}
}
return newDefaultManager(fm), nil return newDefaultManager(fm), nil
} }

View File

@ -1,11 +1,13 @@
package acl package acl
import ( import (
"net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true) ifaceMock.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo") ifaceMock.EXPECT().SetFilter(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any()) ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface) acl, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create ACL manager: %v", err)
return return
@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
iface := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true) ifaceMock.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo") ifaceMock.EXPECT().SetFilter(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any()) ip, network, err := net.ParseCIDR("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface) acl, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create ACL manager: %v", err)
return return

View File

@ -0,0 +1,3 @@
//go:build !linux
package checkfw

View File

@ -0,0 +1,56 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@ -7,6 +7,8 @@ import (
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
const ( const (
@ -26,20 +28,20 @@ func genKey(format string, input string) string {
return fmt.Sprintf(format, input) return fmt.Sprintf(format, input)
} }
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager // newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) (firewallManager, error) { func newFirewall(parentCTX context.Context) (firewallManager, error) {
manager, err := newNFTablesManager(parentCTX) checkResult := checkfw.Check()
if err == nil { switch checkResult {
log.Debugf("nftables firewall manager will be used") case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
return manager, nil log.Debug("creating an iptables firewall manager for route rules")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
return newIptablesManager(parentCTX, ipv6Supported)
case checkfw.NFTABLES:
log.Info("creating an nftables firewall manager for route rules")
return newNFTablesManager(parentCTX), nil
} }
fMgr, err := newIptablesManager(parentCTX)
if err != nil { return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
log.Debugf("failed to initialize iptables for root mgr: %s", err)
return nil, err
}
log.Debugf("iptables firewall manager will be used")
return fMgr, nil
} }
func getInPair(pair routerPair) routerPair { func getInPair(pair routerPair) routerPair {

View File

@ -6,9 +6,10 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"runtime"
) )
// NewFirewall returns a nil manager // newFirewall returns a nil manager
func NewFirewall(context.Context) (firewallManager, error) { func newFirewall(context.Context) (firewallManager, error) {
return nil, fmt.Errorf("firewall not supported on this OS") return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
} }

View File

@ -49,29 +49,28 @@ type iptablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newIptablesManager(parentCtx context.Context) (*iptablesManager, error) { func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
} else if !isIptablesClientAvailable(ipv4Client) {
return nil, fmt.Errorf("iptables is missing for ipv4")
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Debugf("failed to initialize iptables for ipv6: %s", err)
} else if !isIptablesClientAvailable(ipv6Client) {
log.Infof("iptables is missing for ipv6")
ipv6Client = nil
} }
ctx, cancel := context.WithCancel(parentCtx) ctx, cancel := context.WithCancel(parentCtx)
return &iptablesManager{ manager := &iptablesManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
ipv4Client: ipv4Client, ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string), rules: make(map[string]map[string][]string),
}, nil }
if ipv6Supported {
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
}
}
return manager, nil
} }
// CleanRoutingRules cleans existing iptables resources that we created by the agent // CleanRoutingRules cleans existing iptables resources that we created by the agent
@ -486,8 +485,3 @@ func getIptablesRuleType(table string) string {
} }
return ruleType return ruleType
} }
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
manager, _ := newIptablesManager(context.TODO()) manager, err := newIptablesManager(context.TODO(), true)
require.NoError(t, err, "should return a valid iptables manager")
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers() err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")

View File

@ -36,7 +36,7 @@ type DefaultManager struct {
// NewManager returns a new route manager // NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
serverRouter, err := newServerRouter(ctx, wgInterface) srvRouter, err := newServerRouter(ctx, wgInterface)
if err != nil { if err != nil {
log.Errorf("server router is not supported: %s", err) log.Errorf("server router is not supported: %s", err)
} }
@ -46,7 +46,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[string]*clientNetwork),
serverRouter: serverRouter, serverRouter: srvRouter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,

View File

@ -3,11 +3,12 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pion/transport/v2/stdnet"
"net/netip" "net/netip"
"runtime" "runtime"
"testing" "testing"
"github.com/pion/transport/v2/stdnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@ -30,6 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
inputInitRoutes []*route.Route inputInitRoutes []*route.Route
inputRoutes []*route.Route inputRoutes []*route.Route
inputSerial uint64 inputSerial uint64
removeSrvRouter bool
serverRoutesExpected int serverRoutesExpected int
clientNetworkWatchersExpected int clientNetworkWatchersExpected int
}{ }{
@ -117,6 +119,35 @@ func TestManagerUpdateRoutes(t *testing.T) {
serverRoutesExpected: 1, serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1, clientNetworkWatchersExpected: 1,
}, },
{
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
removeSrvRouter: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 1,
},
{ {
name: "Should Create 1 HA Route and 1 Standalone", name: "Should Create 1 HA Route and 1 Standalone",
inputRoutes: []*route.Route{ inputRoutes: []*route.Route{
@ -385,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
defer routeManager.Stop() defer routeManager.Stop()
if testCase.removeSrvRouter {
routeManager.serverRouter = nil
}
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
@ -395,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
sr := routeManager.serverRouter.(*defaultServerRouter) sr := routeManager.serverRouter.(*defaultServerRouter)
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match") require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
} }

View File

@ -86,10 +86,10 @@ type nftablesManager struct {
mux sync.Mutex mux sync.Mutex
} }
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { func newNFTablesManager(parentCtx context.Context) *nftablesManager {
ctx, cancel := context.WithCancel(parentCtx) ctx, cancel := context.WithCancel(parentCtx)
mgr := &nftablesManager{ return &nftablesManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
conn: &nftables.Conn{}, conn: &nftables.Conn{},
@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2), defaultForwardRules: make([]*nftables.Rule, 2),
} }
err := mgr.isSupported()
if err != nil {
return nil, err
}
err = mgr.readFilterTable()
if err != nil {
return nil, err
}
return mgr, nil
} }
// CleanRoutingRules cleans existing nftables rules from the system // CleanRoutingRules cleans existing nftables rules from the system
@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error {
} }
for _, table := range tables { for _, table := range tables {
if table.Name == "filter" {
n.filterTable = table
continue
}
if table.Name == nftablesTable { if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 { if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table n.tableIPv4 = table
@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error {
return nil return nil
} }
func (n *nftablesManager) readFilterTable() error {
tables, err := n.conn.ListTables()
if err != nil {
return err
}
for _, t := range tables {
if t.Name == "filter" {
n.filterTable = t
return nil
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error { func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil { if n.defaultForwardRules[0] == nil {
return nil return nil
@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
return nil return nil
} }
func (n *nftablesManager) isSupported() error {
_, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables is not supported: %s", err)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@ -10,20 +10,23 @@ import (
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/checkfw"
) )
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) if checkfw.Check() != checkfw.NFTABLES {
if err != nil { t.Skip("nftables not supported on this OS")
t.Fatalf("failed to create nftables manager: %s", err)
} }
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
} }
func TestNftablesManager_InsertRoutingRules(t *testing.T) { func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range insertRuleTestCases { for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) manager := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair) err = manager.InsertRoutingRules(testCase.inputPair)
@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
} }
func TestNftablesManager_RemoveRoutingRules(t *testing.T) { func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range removeRuleTestCases { for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := newNFTablesManager(context.TODO()) manager := newNFTablesManager(context.TODO())
if err != nil {
t.Fatalf("failed to create nftables manager: %s", err)
}
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules() defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers() err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4 table := manager.tableIPv4

View File

@ -22,7 +22,7 @@ type defaultServerRouter struct {
} }
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
firewall, err := NewFirewall(ctx) firewall, err := newFirewall(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }