From 246abda46d9047b4ed6c2bd70bcb05b29fd050c3 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Tue, 5 Sep 2023 23:07:32 +0400 Subject: [PATCH] Add default firewall rule to allow netbird traffic (#1056) Add a default firewall rule to allow netbird traffic to be handled by the access control managers. Userspace manager behavior: - When running on Windows, a default rule is add on Windows firewall - For Linux, we are using one of the Kernel managers to add a single rule - This PR doesn't handle macOS Kernel manager behavior: - For NFtables, if there is a filter table, an INPUT rule is added - Iptables follows the previous flow if running on kernel mode. If running on userspace mode, it adds a single rule for INPUT and OUTPUT chains A new checkerFW package has been introduced to consolidate checks across route and access control managers. It supports a new environment variable to skip nftables and allow iptables tests --- .gitignore | 1 + client/firewall/firewall.go | 3 + client/firewall/iptables/manager_linux.go | 64 ++++++--- .../firewall/iptables/manager_linux_test.go | 8 +- client/firewall/nftables/manager_linux.go | 129 ++++++++++++++++-- client/firewall/uspfilter/allow_netbird.go | 19 +++ .../firewall/uspfilter/allow_netbird_linux.go | 21 +++ .../uspfilter/allow_netbird_windows.go | 91 ++++++++++++ client/firewall/uspfilter/uspfilter.go | 20 ++- client/firewall/uspfilter/uspfilter_test.go | 8 ++ client/internal/acl/manager_create.go | 5 + client/internal/acl/manager_create_linux.go | 62 +++++++-- client/internal/acl/manager_test.go | 40 ++++-- client/internal/checkfw/check.go | 3 + client/internal/checkfw/check_linux.go | 56 ++++++++ .../internal/routemanager/firewall_linux.go | 28 ++-- .../routemanager/firewall_nonlinux.go | 7 +- .../internal/routemanager/iptables_linux.go | 32 ++--- .../routemanager/iptables_linux_test.go | 5 +- client/internal/routemanager/manager.go | 4 +- client/internal/routemanager/manager_test.go | 39 +++++- .../internal/routemanager/nftables_linux.go | 43 +----- .../routemanager/nftables_linux_test.go | 31 +++-- .../routemanager/server_nonandroid.go | 2 +- 24 files changed, 568 insertions(+), 153 deletions(-) create mode 100644 client/firewall/uspfilter/allow_netbird.go create mode 100644 client/firewall/uspfilter/allow_netbird_linux.go create mode 100644 client/firewall/uspfilter/allow_netbird_windows.go create mode 100644 client/internal/checkfw/check.go create mode 100644 client/internal/checkfw/check_linux.go diff --git a/.gitignore b/.gitignore index 50bbbbe3f..dc62780ad 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ client/.distfiles/ infrastructure_files/setup.env infrastructure_files/setup-*.env .vscode +.DS_Store \ No newline at end of file diff --git a/client/firewall/firewall.go b/client/firewall/firewall.go index 5d003e2f0..59e672a45 100644 --- a/client/firewall/firewall.go +++ b/client/firewall/firewall.go @@ -40,6 +40,9 @@ const ( // It declares methods which handle actions required by the // Netbird client for ACL and routing functionality type Manager interface { + // AllowNetbird allows netbird interface traffic + AllowNetbird() error + // AddFiltering rule to the firewall // // If comment argument is empty firewall manager should set diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index fa51122af..753282d87 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -44,6 +44,7 @@ type Manager struct { type iFaceMapper interface { Name() string Address() iface.WGAddress + IsUserspaceBind() bool } type ruleset struct { @@ -52,7 +53,7 @@ type ruleset struct { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) { m := &Manager{ wgIface: wgIface, inputDefaultRuleSpecs: []string{ @@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) { rulesets: make(map[string]ruleset), } - if err := ipset.Init(); err != nil { + err := ipset.Init() + if err != nil { return nil, fmt.Errorf("init ipset: %w", err) } // init clients for booth ipv4 and ipv6 - ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("iptables is not installed in the system or not supported") } - if isIptablesClientAvailable(ipv4Client) { - m.ipv4Client = ipv4Client + + if ipv6Supported { + m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err) + } } - ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - log.Errorf("ip6tables is not installed in the system or not supported: %v", err) - } else { - if isIptablesClientAvailable(ipv6Client) { - m.ipv6Client = ipv6Client - } + if m.ipv4Client == nil && m.ipv6Client == nil { + return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it") } if err := m.Reset(); err != nil { @@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) { return m, nil } -func isIptablesClientAvailable(client *iptables.IPTables) bool { - _, err := client.ListChains("filter") - return err == nil -} - // AddFiltering rule to the firewall // // If comment is empty rule ID is used as comment @@ -276,6 +272,38 @@ func (m *Manager) Reset() error { return nil } +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + if m.wgIface.IsUserspaceBind() { + _, err := m.AddFiltering( + net.ParseIP("0.0.0.0"), + "all", + nil, + nil, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "allow netbird interface traffic", + ) + if err != nil { + return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + } + _, err = m.AddFiltering( + net.ParseIP("0.0.0.0"), + "all", + nil, + nil, + fw.RuleDirectionOUT, + fw.ActionAccept, + "", + "allow netbird interface traffic", + ) + return err + } + + return nil +} + // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) { return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err) } - if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { + if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil { return nil, fmt.Errorf("failed to create input chain jump rule: %w", err) } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 84e27ed14..2d2013aa2 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress { panic("AddressFunc is not set") } +func (i *iFaceMock) IsUserspaceBind() bool { return false } + func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) @@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, true) require.NoError(t, err) time.Sleep(time.Second) @@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, true) require.NoError(t, err) time.Sleep(time.Second) @@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, true) require.NoError(t, err) time.Sleep(time.Second) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 081aee48d..2273f4edc 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -29,6 +29,8 @@ const ( // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets FilterOutputChainName = "netbird-acl-output-filter" + + AllowNetbirdInputRuleID = "allow Netbird incoming traffic" ) var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} @@ -379,7 +381,7 @@ func (m *Manager) chain( if c != nil { return c, nil } - return m.createChainIfNotExists(tf, name, hook, priority, cType) + return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType) } if ip.To4() != nil { @@ -399,13 +401,20 @@ func (m *Manager) chain( } // table returns the table for the given family of the IP address -func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { +func (m *Manager) table( + family nftables.TableFamily, tableName string, +) (*nftables.Table, error) { + // we cache access to Netbird ACL table only + if tableName != FilterTableName { + return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName) + } + if family == nftables.TableFamilyIPv4 { if m.tableIPv4 != nil { return m.tableIPv4, nil } - table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4) + table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName) if err != nil { return nil, err } @@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { return m.tableIPv6, nil } - table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6) + table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName) if err != nil { return nil, err } @@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { return m.tableIPv6, nil } -func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { +func (m *Manager) createTableIfNotExists( + family nftables.TableFamily, tableName string, +) (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } for _, t := range tables { - if t.Name == FilterTableName { + if t.Name == tableName { return t, nil } } - table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) if err := m.rConn.Flush(); err != nil { return nil, err } @@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables func (m *Manager) createChainIfNotExists( family nftables.TableFamily, + tableName string, name string, hooknum nftables.ChainHook, priority nftables.ChainPriority, chainType nftables.ChainType, ) (*nftables.Chain, error) { - table, err := m.table(family) + table, err := m.table(family, tableName) if err != nil { return nil, err } @@ -638,6 +650,22 @@ func (m *Manager) Reset() error { return fmt.Errorf("list of chains: %w", err) } for _, c := range chains { + // delete Netbird allow input traffic rule if it exists + if c.Table.Name == "filter" && c.Name == "INPUT" { + rules, err := m.rConn.GetRules(c.Table, c) + if err != nil { + log.Errorf("get rules for chain %q: %v", c.Name, err) + continue + } + for _, r := range rules { + if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) { + if err := m.rConn.DelRule(r); err != nil { + log.Errorf("delete rule: %v", err) + } + } + } + } + if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { m.rConn.DelChain(c) } @@ -702,6 +730,53 @@ func (m *Manager) Flush() error { return nil } +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + tf := nftables.TableFamilyIPv4 + if m.wgIface.Address().IP.To4() == nil { + tf = nftables.TableFamilyIPv6 + } + + chains, err := m.rConn.ListChainsOfTableFamily(tf) + if err != nil { + return fmt.Errorf("list of chains: %w", err) + } + + var chain *nftables.Chain + for _, c := range chains { + if c.Table.Name == "filter" && c.Name == "INPUT" { + chain = c + break + } + } + + if chain == nil { + log.Debugf("chain INPUT not found. Skiping add allow netbird rule") + return nil + } + + rules, err := m.rConn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) + } + + if rule := m.detectAllowNetbirdRule(rules); rule != nil { + log.Debugf("allow netbird rule already exists: %v", rule) + return nil + } + + m.applyAllowNetbirdRules(chain) + + err = m.rConn.Flush() + if err != nil { + return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + } + return nil +} + func (m *Manager) flushWithBackoff() (err error) { backoff := 4 backoffTime := 1000 * time.Millisecond @@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai return nil } +func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { + rule := &nftables.Rule{ + Table: chain.Table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: []byte(AllowNetbirdInputRuleID), + } + _ = m.rConn.InsertRule(rule) +} + +func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { + ifName := ifname(m.wgIface.Name()) + for _, rule := range existedRules { + if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { + if len(rule.Exprs) < 4 { + if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { + continue + } + if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { + continue + } + return rule + } + } + } + return nil +} + func encodePort(port fw.Port) []byte { bs := make([]byte, 2) binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go new file mode 100644 index 000000000..ccfef1861 --- /dev/null +++ b/client/firewall/uspfilter/allow_netbird.go @@ -0,0 +1,19 @@ +//go:build !windows && !linux + +package uspfilter + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) + + return nil +} + +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird_linux.go b/client/firewall/uspfilter/allow_netbird_linux.go new file mode 100644 index 000000000..5df48c756 --- /dev/null +++ b/client/firewall/uspfilter/allow_netbird_linux.go @@ -0,0 +1,21 @@ +package uspfilter + +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + return nil +} + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) + + if m.resetHook != nil { + return m.resetHook() + } + + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go new file mode 100644 index 000000000..05a6d22ae --- /dev/null +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -0,0 +1,91 @@ +package uspfilter + +import ( + "errors" + "fmt" + "os/exec" + "strings" + "syscall" +) + +type action string + +const ( + addRule action = "add" + deleteRule action = "delete" + + firewallRuleName = "Netbird" + noRulesMatchCriteria = "No rules match the specified criteria" +) + +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.outgoingRules = make(map[string]RuleSet) + m.incomingRules = make(map[string]RuleSet) + + if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { + return fmt.Errorf("couldn't remove windows firewall: %w", err) + } + + return nil +} + +// AllowNetbird allows netbird interface traffic +func (m *Manager) AllowNetbird() error { + return manageFirewallRule(firewallRuleName, + addRule, + "dir=in", + "enable=yes", + "action=allow", + "profile=any", + "localip="+m.wgIface.Address().IP.String(), + ) +} + +func manageFirewallRule(ruleName string, action action, args ...string) error { + active, err := isFirewallRuleActive(ruleName) + if err != nil { + return err + } + + if (action == addRule && !active) || (action == deleteRule && active) { + baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName} + args := append(baseArgs, args...) + + cmd := exec.Command("netsh", args...) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + return cmd.Run() + } + + return nil +} + +func isFirewallRuleActive(ruleName string) (bool, error) { + args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName} + + cmd := exec.Command("netsh", args...) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + output, err := cmd.Output() + if err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + // if the firewall rule is not active, we expect last exit code to be 1 + exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus() + if exitStatus == 1 { + if strings.Contains(string(output), noRulesMatchCriteria) { + return false, nil + } + } + } + return false, err + } + + if strings.Contains(string(output), noRulesMatchCriteria) { + return false, nil + } + + return true, nil +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 3dead1db4..50170b46c 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -19,6 +19,7 @@ const layerTypeAll = 0 // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { SetFilter(iface.PacketFilter) error + Address() iface.WGAddress } // RuleSet is a set of rules grouped by a string key @@ -30,6 +31,8 @@ type Manager struct { incomingRules map[string]RuleSet wgNetwork *net.IPNet decoders sync.Pool + wgIface IFaceMapper + resetHook func() error mutex sync.RWMutex } @@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) { }, outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), + wgIface: iface, } if err := iface.SetFilter(m); err != nil { @@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { return nil } -// Reset firewall to the default state -func (m *Manager) Reset() error { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.outgoingRules = make(map[string]RuleSet) - m.incomingRules = make(map[string]RuleSet) - - return nil -} - // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error { } return fmt.Errorf("hook with given id not found") } + +// SetResetHook which will be executed in the end of Reset method +func (m *Manager) SetResetHook(hook func() error) { + m.resetHook = hook +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index bc94f59c1..6b3d334a8 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -16,6 +16,7 @@ import ( type IFaceMock struct { SetFilterFunc func(iface.PacketFilter) error + AddressFunc func() iface.WGAddress } func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { @@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { return i.SetFilterFunc(iface) } +func (i *IFaceMock) Address() iface.WGAddress { + if i.AddressFunc == nil { + return iface.WGAddress{} + } + return i.AddressFunc() +} + func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(iface.PacketFilter) error { return nil }, diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go index 7d9e6b430..c573d2c64 100644 --- a/client/internal/acl/manager_create.go +++ b/client/internal/acl/manager_create.go @@ -6,6 +6,8 @@ import ( "fmt" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter" ) @@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) { if err != nil { return nil, err } + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } return newDefaultManager(fm), nil } return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) diff --git a/client/internal/acl/manager_create_linux.go b/client/internal/acl/manager_create_linux.go index de4e8adb9..4342463d3 100644 --- a/client/internal/acl/manager_create_linux.go +++ b/client/internal/acl/manager_create_linux.go @@ -7,26 +7,68 @@ import ( "github.com/netbirdio/netbird/client/firewall/iptables" "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/checkfw" ) // Create creates a firewall manager instance for the Linux -func Create(iface IFaceMapper) (manager *DefaultManager, err error) { +func Create(iface IFaceMapper) (*DefaultManager, error) { + // on the linux system we try to user nftables or iptables + // in any case, because we need to allow netbird interface traffic + // so we use AllowNetbird traffic from these firewall managers + // for the userspace packet filtering firewall var fm firewall.Manager + var err error + + checkResult := checkfw.Check() + switch checkResult { + case checkfw.IPTABLES, checkfw.IPTABLESWITHV6: + log.Debug("creating an iptables firewall manager for access control") + ipv6Supported := checkResult == checkfw.IPTABLESWITHV6 + if fm, err = iptables.Create(iface, ipv6Supported); err != nil { + log.Infof("failed to create iptables manager for access control: %s", err) + } + case checkfw.NFTABLES: + log.Debug("creating an nftables firewall manager for access control") + if fm, err = nftables.Create(iface); err != nil { + log.Debugf("failed to create nftables manager for access control: %s", err) + } + } + + var resetHookForUserspace func() error + if fm != nil && err == nil { + // err shadowing is used here, to ignore this error + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } + resetHookForUserspace = fm.Reset + } + if iface.IsUserspaceBind() { // use userspace packet filtering firewall - if fm, err = uspfilter.Create(iface); err != nil { + usfm, err := uspfilter.Create(iface) + if err != nil { log.Debugf("failed to create userspace filtering firewall: %s", err) return nil, err } - } else { - if fm, err = nftables.Create(iface); err != nil { - log.Debugf("failed to create nftables manager: %s", err) - // fallback to iptables - if fm, err = iptables.Create(iface); err != nil { - log.Errorf("failed to create iptables manager: %s", err) - return nil, err - } + + // set kernel space firewall Reset as hook for userspace firewall + // manager Reset method, to clean up + if resetHookForUserspace != nil { + usfm.SetResetHook(resetHookForUserspace) } + + // to be consistent for any future extensions. + // ignore this error + if err := usfm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } + fm = usfm + } + + if fm == nil || err != nil { + log.Errorf("failed to create firewall manager: %s", err) + // no firewall manager found or initialized correctly + return nil, err } return newDefaultManager(fm), nil diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index d765e5c6c..518e895cf 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,11 +1,13 @@ package acl import ( + "net" "testing" "github.com/golang/mock/gomock" "github.com/netbirdio/netbird/client/internal/acl/mocks" + "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - iface := mocks.NewMockIFaceMapper(ctrl) - iface.EXPECT().IsUserspaceBind().Return(true) - // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFilter(gomock.Any()) + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true) + ifaceMock.EXPECT().SetFilter(gomock.Any()) + ip, network, err := net.ParseCIDR("172.0.0.1/32") + if err != nil { + t.Fatalf("failed to parse IP address: %v", err) + } + + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(iface.WGAddress{ + IP: ip, + Network: network, + }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - acl, err := Create(iface) + acl, err := Create(ifaceMock) if err != nil { t.Errorf("create ACL manager: %v", err) return @@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - iface := mocks.NewMockIFaceMapper(ctrl) - iface.EXPECT().IsUserspaceBind().Return(true) - // iface.EXPECT().Name().Return("lo") - iface.EXPECT().SetFilter(gomock.Any()) + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true) + ifaceMock.EXPECT().SetFilter(gomock.Any()) + ip, network, err := net.ParseCIDR("172.0.0.1/32") + if err != nil { + t.Fatalf("failed to parse IP address: %v", err) + } + + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(iface.WGAddress{ + IP: ip, + Network: network, + }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - acl, err := Create(iface) + acl, err := Create(ifaceMock) if err != nil { t.Errorf("create ACL manager: %v", err) return diff --git a/client/internal/checkfw/check.go b/client/internal/checkfw/check.go new file mode 100644 index 000000000..59626cbc3 --- /dev/null +++ b/client/internal/checkfw/check.go @@ -0,0 +1,3 @@ +//go:build !linux + +package checkfw diff --git a/client/internal/checkfw/check_linux.go b/client/internal/checkfw/check_linux.go new file mode 100644 index 000000000..552d5698c --- /dev/null +++ b/client/internal/checkfw/check_linux.go @@ -0,0 +1,56 @@ +//go:build !android + +package checkfw + +import ( + "os" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/nftables" +) + +const ( + // UNKNOWN is the default value for the firewall type for unknown firewall type + UNKNOWN FWType = iota + // IPTABLES is the value for the iptables firewall type + IPTABLES + // IPTABLESWITHV6 is the value for the iptables firewall type with ipv6 + IPTABLESWITHV6 + // NFTABLES is the value for the nftables firewall type + NFTABLES +) + +// SKIP_NFTABLES_ENV is the environment variable to skip nftables check +const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" + +// FWType is the type for the firewall type +type FWType int + +// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. +func Check() FWType { + nf := nftables.Conn{} + if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { + return NFTABLES + } + + ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err == nil { + if isIptablesClientAvailable(ip) { + ipSupport := IPTABLES + ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if ip6Err == nil { + if isIptablesClientAvailable(ipv6) { + ipSupport = IPTABLESWITHV6 + } + } + return ipSupport + } + } + + return UNKNOWN +} + +func isIptablesClientAvailable(client *iptables.IPTables) bool { + _, err := client.ListChains("filter") + return err == nil +} diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index 19a5a4cde..50d451a88 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -7,6 +7,8 @@ import ( "fmt" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/checkfw" ) const ( @@ -26,20 +28,20 @@ func genKey(format string, input string) string { return fmt.Sprintf(format, input) } -// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager -func NewFirewall(parentCTX context.Context) (firewallManager, error) { - manager, err := newNFTablesManager(parentCTX) - if err == nil { - log.Debugf("nftables firewall manager will be used") - return manager, nil +// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager +func newFirewall(parentCTX context.Context) (firewallManager, error) { + checkResult := checkfw.Check() + switch checkResult { + case checkfw.IPTABLES, checkfw.IPTABLESWITHV6: + log.Debug("creating an iptables firewall manager for route rules") + ipv6Supported := checkResult == checkfw.IPTABLESWITHV6 + return newIptablesManager(parentCTX, ipv6Supported) + case checkfw.NFTABLES: + log.Info("creating an nftables firewall manager for route rules") + return newNFTablesManager(parentCTX), nil } - fMgr, err := newIptablesManager(parentCTX) - if err != nil { - log.Debugf("failed to initialize iptables for root mgr: %s", err) - return nil, err - } - log.Debugf("iptables firewall manager will be used") - return fMgr, nil + + return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules") } func getInPair(pair routerPair) routerPair { diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go index 1b52a1e85..ae0627048 100644 --- a/client/internal/routemanager/firewall_nonlinux.go +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -6,9 +6,10 @@ package routemanager import ( "context" "fmt" + "runtime" ) -// NewFirewall returns a nil manager -func NewFirewall(context.Context) (firewallManager, error) { - return nil, fmt.Errorf("firewall not supported on this OS") +// newFirewall returns a nil manager +func newFirewall(context.Context) (firewallManager, error) { + return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS) } diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index a87d4f4a3..9f6019305 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -49,29 +49,28 @@ type iptablesManager struct { mux sync.Mutex } -func newIptablesManager(parentCtx context.Context) (*iptablesManager, error) { +func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { - return nil, err - } else if !isIptablesClientAvailable(ipv4Client) { - return nil, fmt.Errorf("iptables is missing for ipv4") - } - ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - log.Debugf("failed to initialize iptables for ipv6: %s", err) - } else if !isIptablesClientAvailable(ipv6Client) { - log.Infof("iptables is missing for ipv6") - ipv6Client = nil + return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err) } ctx, cancel := context.WithCancel(parentCtx) - return &iptablesManager{ + manager := &iptablesManager{ ctx: ctx, stop: cancel, ipv4Client: ipv4Client, - ipv6Client: ipv6Client, rules: make(map[string]map[string][]string), - }, nil + } + + if ipv6Supported { + manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err) + } + } + + return manager, nil } // CleanRoutingRules cleans existing iptables resources that we created by the agent @@ -486,8 +485,3 @@ func getIptablesRuleType(table string) string { } return ruleType } - -func isIptablesClientAvailable(client *iptables.IPTables) bool { - _, err := client.ListChains("filter") - return err == nil -} diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go index dbe153f7b..4f733de34 100644 --- a/client/internal/routemanager/iptables_linux_test.go +++ b/client/internal/routemanager/iptables_linux_test.go @@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { t.SkipNow() } - manager, _ := newIptablesManager(context.TODO()) + manager, err := newIptablesManager(context.TODO(), true) + require.NoError(t, err, "should return a valid iptables manager") defer manager.CleanRoutingRules() - err := manager.RestoreOrCreateContainers() + err = manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 13d9d1f38..b31fe6327 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,7 +36,7 @@ type DefaultManager struct { // NewManager returns a new route manager func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { - serverRouter, err := newServerRouter(ctx, wgInterface) + srvRouter, err := newServerRouter(ctx, wgInterface) if err != nil { log.Errorf("server router is not supported: %s", err) } @@ -46,7 +46,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), - serverRouter: serverRouter, + serverRouter: srvRouter, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 6f2ac294d..f6f5f359e 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -3,11 +3,12 @@ package routemanager import ( "context" "fmt" - "github.com/pion/transport/v2/stdnet" "net/netip" "runtime" "testing" + "github.com/pion/transport/v2/stdnet" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal/peer" @@ -30,6 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) { inputInitRoutes []*route.Route inputRoutes []*route.Route inputSerial uint64 + removeSrvRouter bool serverRoutesExpected int clientNetworkWatchersExpected int }{ @@ -117,6 +119,35 @@ func TestManagerUpdateRoutes(t *testing.T) { serverRoutesExpected: 1, clientNetworkWatchersExpected: 1, }, + { + name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.30.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.9.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + removeSrvRouter: true, + serverRoutesExpected: 0, + clientNetworkWatchersExpected: 1, + }, { name: "Should Create 1 HA Route and 1 Standalone", inputRoutes: []*route.Route{ @@ -385,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) { routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) defer routeManager.Stop() + if testCase.removeSrvRouter { + routeManager.serverRouter = nil + } + if len(testCase.inputInitRoutes) > 0 { err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) require.NoError(t, err, "should update routes with init routes") @@ -395,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) { require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" && routeManager.serverRouter != nil { sr := routeManager.serverRouter.(*defaultServerRouter) require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match") } diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index ca7d74f2a..25dc6e7db 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -86,10 +86,10 @@ type nftablesManager struct { mux sync.Mutex } -func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { +func newNFTablesManager(parentCtx context.Context) *nftablesManager { ctx, cancel := context.WithCancel(parentCtx) - mgr := &nftablesManager{ + return &nftablesManager{ ctx: ctx, stop: cancel, conn: &nftables.Conn{}, @@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { rules: make(map[string]*nftables.Rule), defaultForwardRules: make([]*nftables.Rule, 2), } - - err := mgr.isSupported() - if err != nil { - return nil, err - } - - err = mgr.readFilterTable() - if err != nil { - return nil, err - } - - return mgr, nil } // CleanRoutingRules cleans existing nftables rules from the system @@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { } for _, table := range tables { + if table.Name == "filter" { + n.filterTable = table + continue + } if table.Name == nftablesTable { if table.Family == nftables.TableFamilyIPv4 { n.tableIPv4 = table @@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error { return nil } -func (n *nftablesManager) readFilterTable() error { - tables, err := n.conn.ListTables() - if err != nil { - return err - } - - for _, t := range tables { - if t.Name == "filter" { - n.filterTable = t - return nil - } - } - return nil -} - func (n *nftablesManager) eraseDefaultForwardRule() error { if n.defaultForwardRules[0] == nil { return nil @@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro return nil } -func (n *nftablesManager) isSupported() error { - _, err := n.conn.ListChains() - if err != nil { - return fmt.Errorf("nftables is not supported: %s", err) - } - return nil -} - // getPayloadDirectives get expression directives based on ip version and direction func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { switch { diff --git a/client/internal/routemanager/nftables_linux_test.go b/client/internal/routemanager/nftables_linux_test.go index 01fc38885..dec800156 100644 --- a/client/internal/routemanager/nftables_linux_test.go +++ b/client/internal/routemanager/nftables_linux_test.go @@ -10,20 +10,23 @@ import ( "github.com/google/nftables/expr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/checkfw" ) func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { - manager, err := newNFTablesManager(context.TODO()) - if err != nil { - t.Fatalf("failed to create nftables manager: %s", err) + if checkfw.Check() != checkfw.NFTABLES { + t.Skip("nftables not supported on this OS") } + manager := newNFTablesManager(context.TODO()) + nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err = manager.RestoreOrCreateContainers() + err := manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") @@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { } func TestNftablesManager_InsertRoutingRules(t *testing.T) { + if checkfw.Check() != checkfw.NFTABLES { + t.Skip("nftables not supported on this OS") + } for _, testCase := range insertRuleTestCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := newNFTablesManager(context.TODO()) - if err != nil { - t.Fatalf("failed to create nftables manager: %s", err) - } + manager := newNFTablesManager(context.TODO()) nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err = manager.RestoreOrCreateContainers() + err := manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") err = manager.InsertRoutingRules(testCase.inputPair) @@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { } func TestNftablesManager_RemoveRoutingRules(t *testing.T) { + if checkfw.Check() != checkfw.NFTABLES { + t.Skip("nftables not supported on this OS") + } for _, testCase := range removeRuleTestCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := newNFTablesManager(context.TODO()) - if err != nil { - t.Fatalf("failed to create nftables manager: %s", err) - } + manager := newNFTablesManager(context.TODO()) nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err = manager.RestoreOrCreateContainers() + err := manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") table := manager.tableIPv4 diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index bf7a1dfd4..6df632329 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -22,7 +22,7 @@ type defaultServerRouter struct { } func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) { - firewall, err := NewFirewall(ctx) + firewall, err := newFirewall(ctx) if err != nil { return nil, err }