mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
Add default firewall rule to allow netbird traffic (#1056)
Add a default firewall rule to allow netbird traffic to be handled by the access control managers. Userspace manager behavior: - When running on Windows, a default rule is add on Windows firewall - For Linux, we are using one of the Kernel managers to add a single rule - This PR doesn't handle macOS Kernel manager behavior: - For NFtables, if there is a filter table, an INPUT rule is added - Iptables follows the previous flow if running on kernel mode. If running on userspace mode, it adds a single rule for INPUT and OUTPUT chains A new checkerFW package has been introduced to consolidate checks across route and access control managers. It supports a new environment variable to skip nftables and allow iptables tests
This commit is contained in:
parent
e4bc76c4de
commit
246abda46d
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,3 +19,4 @@ client/.distfiles/
|
|||||||
infrastructure_files/setup.env
|
infrastructure_files/setup.env
|
||||||
infrastructure_files/setup-*.env
|
infrastructure_files/setup-*.env
|
||||||
.vscode
|
.vscode
|
||||||
|
.DS_Store
|
@ -40,6 +40,9 @@ const (
|
|||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
// Netbird client for ACL and routing functionality
|
// Netbird client for ACL and routing functionality
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
AllowNetbird() error
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
|
@ -44,6 +44,7 @@ type Manager struct {
|
|||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() iface.WGAddress
|
||||||
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ruleset struct {
|
type ruleset struct {
|
||||||
@ -52,7 +53,7 @@ type ruleset struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
inputDefaultRuleSpecs: []string{
|
inputDefaultRuleSpecs: []string{
|
||||||
@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
rulesets: make(map[string]ruleset),
|
rulesets: make(map[string]ruleset),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ipset.Init(); err != nil {
|
err := ipset.Init()
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("init ipset: %w", err)
|
return nil, fmt.Errorf("init ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// init clients for booth ipv4 and ipv6
|
// init clients for booth ipv4 and ipv6
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||||
}
|
}
|
||||||
if isIptablesClientAvailable(ipv4Client) {
|
|
||||||
m.ipv4Client = ipv4Client
|
if ipv6Supported {
|
||||||
|
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
if m.ipv4Client == nil && m.ipv6Client == nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
|
||||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
|
||||||
} else {
|
|
||||||
if isIptablesClientAvailable(ipv6Client) {
|
|
||||||
m.ipv6Client = ipv6Client
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Reset(); err != nil {
|
if err := m.Reset(); err != nil {
|
||||||
@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFiltering rule to the firewall
|
// AddFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// If comment is empty rule ID is used as comment
|
// If comment is empty rule ID is used as comment
|
||||||
@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
if m.wgIface.IsUserspaceBind() {
|
||||||
|
_, err := m.AddFiltering(
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
"all",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionIN,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
"allow netbird interface traffic",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
||||||
|
}
|
||||||
|
_, err = m.AddFiltering(
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
"all",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.RuleDirectionOUT,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
"allow netbird interface traffic",
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
|||||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
|
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
|
|||||||
panic("AddressFunc is not set")
|
panic("AddressFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||||
|
|
||||||
func TestIptablesManager(t *testing.T) {
|
func TestIptablesManager(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
@ -29,6 +29,8 @@ const (
|
|||||||
|
|
||||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||||
FilterOutputChainName = "netbird-acl-output-filter"
|
FilterOutputChainName = "netbird-acl-output-filter"
|
||||||
|
|
||||||
|
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
|
|
||||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
@ -379,7 +381,7 @@ func (m *Manager) chain(
|
|||||||
if c != nil {
|
if c != nil {
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.To4() != nil {
|
||||||
@ -399,13 +401,20 @@ func (m *Manager) chain(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// table returns the table for the given family of the IP address
|
// table returns the table for the given family of the IP address
|
||||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) table(
|
||||||
|
family nftables.TableFamily, tableName string,
|
||||||
|
) (*nftables.Table, error) {
|
||||||
|
// we cache access to Netbird ACL table only
|
||||||
|
if tableName != FilterTableName {
|
||||||
|
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
if family == nftables.TableFamilyIPv4 {
|
if family == nftables.TableFamilyIPv4 {
|
||||||
if m.tableIPv4 != nil {
|
if m.tableIPv4 != nil {
|
||||||
return m.tableIPv4, nil
|
return m.tableIPv4, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
return m.tableIPv6, nil
|
return m.tableIPv6, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
|||||||
return m.tableIPv6, nil
|
return m.tableIPv6, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
func (m *Manager) createTableIfNotExists(
|
||||||
|
family nftables.TableFamily, tableName string,
|
||||||
|
) (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list of tables: %w", err)
|
return nil, fmt.Errorf("list of tables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
if t.Name == FilterTableName {
|
if t.Name == tableName {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
|||||||
|
|
||||||
func (m *Manager) createChainIfNotExists(
|
func (m *Manager) createChainIfNotExists(
|
||||||
family nftables.TableFamily,
|
family nftables.TableFamily,
|
||||||
|
tableName string,
|
||||||
name string,
|
name string,
|
||||||
hooknum nftables.ChainHook,
|
hooknum nftables.ChainHook,
|
||||||
priority nftables.ChainPriority,
|
priority nftables.ChainPriority,
|
||||||
chainType nftables.ChainType,
|
chainType nftables.ChainType,
|
||||||
) (*nftables.Chain, error) {
|
) (*nftables.Chain, error) {
|
||||||
table, err := m.table(family)
|
table, err := m.table(family, tableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
|
|||||||
return fmt.Errorf("list of chains: %w", err)
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
}
|
}
|
||||||
for _, c := range chains {
|
for _, c := range chains {
|
||||||
|
// delete Netbird allow input traffic rule if it exists
|
||||||
|
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||||
|
rules, err := m.rConn.GetRules(c.Table, c)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, r := range rules {
|
||||||
|
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
|
||||||
|
if err := m.rConn.DelRule(r); err != nil {
|
||||||
|
log.Errorf("delete rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||||
m.rConn.DelChain(c)
|
m.rConn.DelChain(c)
|
||||||
}
|
}
|
||||||
@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
tf := nftables.TableFamilyIPv4
|
||||||
|
if m.wgIface.Address().IP.To4() == nil {
|
||||||
|
tf = nftables.TableFamilyIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
chains, err := m.rConn.ListChainsOfTableFamily(tf)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list of chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chain *nftables.Chain
|
||||||
|
for _, c := range chains {
|
||||||
|
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||||
|
chain = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if chain == nil {
|
||||||
|
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||||
|
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.applyAllowNetbirdRules(chain)
|
||||||
|
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) flushWithBackoff() (err error) {
|
func (m *Manager) flushWithBackoff() (err error) {
|
||||||
backoff := 4
|
backoff := 4
|
||||||
backoffTime := 1000 * time.Millisecond
|
backoffTime := 1000 * time.Millisecond
|
||||||
@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: []byte(AllowNetbirdInputRuleID),
|
||||||
|
}
|
||||||
|
_ = m.rConn.InsertRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||||
|
ifName := ifname(m.wgIface.Name())
|
||||||
|
for _, rule := range existedRules {
|
||||||
|
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
|
||||||
|
if len(rule.Exprs) < 4 {
|
||||||
|
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func encodePort(port fw.Port) []byte {
|
func encodePort(port fw.Port) []byte {
|
||||||
bs := make([]byte, 2)
|
bs := make([]byte, 2)
|
||||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||||
|
19
client/firewall/uspfilter/allow_netbird.go
Normal file
19
client/firewall/uspfilter/allow_netbird.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
//go:build !windows && !linux
|
||||||
|
|
||||||
|
package uspfilter
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
return nil
|
||||||
|
}
|
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
21
client/firewall/uspfilter/allow_netbird_linux.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if m.resetHook != nil {
|
||||||
|
return m.resetHook()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
91
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
91
client/firewall/uspfilter/allow_netbird_windows.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
type action string
|
||||||
|
|
||||||
|
const (
|
||||||
|
addRule action = "add"
|
||||||
|
deleteRule action = "delete"
|
||||||
|
|
||||||
|
firewallRuleName = "Netbird"
|
||||||
|
noRulesMatchCriteria = "No rules match the specified criteria"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reset firewall to the default state
|
||||||
|
func (m *Manager) Reset() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.outgoingRules = make(map[string]RuleSet)
|
||||||
|
m.incomingRules = make(map[string]RuleSet)
|
||||||
|
|
||||||
|
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||||
|
return fmt.Errorf("couldn't remove windows firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
return manageFirewallRule(firewallRuleName,
|
||||||
|
addRule,
|
||||||
|
"dir=in",
|
||||||
|
"enable=yes",
|
||||||
|
"action=allow",
|
||||||
|
"profile=any",
|
||||||
|
"localip="+m.wgIface.Address().IP.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func manageFirewallRule(ruleName string, action action, args ...string) error {
|
||||||
|
active, err := isFirewallRuleActive(ruleName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if (action == addRule && !active) || (action == deleteRule && active) {
|
||||||
|
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||||
|
args := append(baseArgs, args...)
|
||||||
|
|
||||||
|
cmd := exec.Command("netsh", args...)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isFirewallRuleActive(ruleName string) (bool, error) {
|
||||||
|
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
|
||||||
|
|
||||||
|
cmd := exec.Command("netsh", args...)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
output, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
var exitError *exec.ExitError
|
||||||
|
if errors.As(err, &exitError) {
|
||||||
|
// if the firewall rule is not active, we expect last exit code to be 1
|
||||||
|
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
|
||||||
|
if exitStatus == 1 {
|
||||||
|
if strings.Contains(string(output), noRulesMatchCriteria) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(string(output), noRulesMatchCriteria) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
@ -19,6 +19,7 @@ const layerTypeAll = 0
|
|||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
SetFilter(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
|
Address() iface.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
@ -30,6 +31,8 @@ type Manager struct {
|
|||||||
incomingRules map[string]RuleSet
|
incomingRules map[string]RuleSet
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
|
wgIface IFaceMapper
|
||||||
|
resetHook func() error
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
},
|
},
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
|
wgIface: iface,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
|
||||||
func (m *Manager) Reset() error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
|
|||||||
}
|
}
|
||||||
return fmt.Errorf("hook with given id not found")
|
return fmt.Errorf("hook with given id not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetResetHook which will be executed in the end of Reset method
|
||||||
|
func (m *Manager) SetResetHook(hook func() error) {
|
||||||
|
m.resetHook = hook
|
||||||
|
}
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(iface.PacketFilter) error
|
SetFilterFunc func(iface.PacketFilter) error
|
||||||
|
AddressFunc func() iface.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||||
@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
|||||||
return i.SetFilterFunc(iface)
|
return i.SetFilterFunc(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) Address() iface.WGAddress {
|
||||||
|
if i.AddressFunc == nil {
|
||||||
|
return iface.WGAddress{}
|
||||||
|
}
|
||||||
|
return i.AddressFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerCreate(t *testing.T) {
|
func TestManagerCreate(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
|
@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
return newDefaultManager(fm), nil
|
return newDefaultManager(fm), nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
|
@ -7,26 +7,68 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create creates a firewall manager instance for the Linux
|
// Create creates a firewall manager instance for the Linux
|
||||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
func Create(iface IFaceMapper) (*DefaultManager, error) {
|
||||||
|
// on the linux system we try to user nftables or iptables
|
||||||
|
// in any case, because we need to allow netbird interface traffic
|
||||||
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
|
// for the userspace packet filtering firewall
|
||||||
var fm firewall.Manager
|
var fm firewall.Manager
|
||||||
|
var err error
|
||||||
|
|
||||||
|
checkResult := checkfw.Check()
|
||||||
|
switch checkResult {
|
||||||
|
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||||
|
log.Debug("creating an iptables firewall manager for access control")
|
||||||
|
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||||
|
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
|
||||||
|
log.Infof("failed to create iptables manager for access control: %s", err)
|
||||||
|
}
|
||||||
|
case checkfw.NFTABLES:
|
||||||
|
log.Debug("creating an nftables firewall manager for access control")
|
||||||
|
if fm, err = nftables.Create(iface); err != nil {
|
||||||
|
log.Debugf("failed to create nftables manager for access control: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resetHookForUserspace func() error
|
||||||
|
if fm != nil && err == nil {
|
||||||
|
// err shadowing is used here, to ignore this error
|
||||||
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
resetHookForUserspace = fm.Reset
|
||||||
|
}
|
||||||
|
|
||||||
if iface.IsUserspaceBind() {
|
if iface.IsUserspaceBind() {
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
if fm, err = uspfilter.Create(iface); err != nil {
|
usfm, err := uspfilter.Create(iface)
|
||||||
|
if err != nil {
|
||||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if fm, err = nftables.Create(iface); err != nil {
|
// set kernel space firewall Reset as hook for userspace firewall
|
||||||
log.Debugf("failed to create nftables manager: %s", err)
|
// manager Reset method, to clean up
|
||||||
// fallback to iptables
|
if resetHookForUserspace != nil {
|
||||||
if fm, err = iptables.Create(iface); err != nil {
|
usfm.SetResetHook(resetHookForUserspace)
|
||||||
log.Errorf("failed to create iptables manager: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// to be consistent for any future extensions.
|
||||||
|
// ignore this error
|
||||||
|
if err := usfm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
fm = usfm
|
||||||
|
}
|
||||||
|
|
||||||
|
if fm == nil || err != nil {
|
||||||
|
log.Errorf("failed to create firewall manager: %s", err)
|
||||||
|
// no firewall manager found or initialized correctly
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDefaultManager(fm), nil
|
return newDefaultManager(fm), nil
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/iface"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||||
// iface.EXPECT().Name().Return("lo")
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
iface.EXPECT().SetFilter(gomock.Any())
|
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(iface)
|
acl, err := Create(ifaceMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create ACL manager: %v", err)
|
t.Errorf("create ACL manager: %v", err)
|
||||||
return
|
return
|
||||||
@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||||
// iface.EXPECT().Name().Return("lo")
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
iface.EXPECT().SetFilter(gomock.Any())
|
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||||
|
IP: ip,
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(iface)
|
acl, err := Create(ifaceMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create ACL manager: %v", err)
|
t.Errorf("create ACL manager: %v", err)
|
||||||
return
|
return
|
||||||
|
3
client/internal/checkfw/check.go
Normal file
3
client/internal/checkfw/check.go
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package checkfw
|
56
client/internal/checkfw/check_linux.go
Normal file
56
client/internal/checkfw/check_linux.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package checkfw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/nftables"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||||
|
UNKNOWN FWType = iota
|
||||||
|
// IPTABLES is the value for the iptables firewall type
|
||||||
|
IPTABLES
|
||||||
|
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
|
||||||
|
IPTABLESWITHV6
|
||||||
|
// NFTABLES is the value for the nftables firewall type
|
||||||
|
NFTABLES
|
||||||
|
)
|
||||||
|
|
||||||
|
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||||
|
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||||
|
|
||||||
|
// FWType is the type for the firewall type
|
||||||
|
type FWType int
|
||||||
|
|
||||||
|
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
|
func Check() FWType {
|
||||||
|
nf := nftables.Conn{}
|
||||||
|
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||||
|
return NFTABLES
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
if err == nil {
|
||||||
|
if isIptablesClientAvailable(ip) {
|
||||||
|
ipSupport := IPTABLES
|
||||||
|
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if ip6Err == nil {
|
||||||
|
if isIptablesClientAvailable(ipv6) {
|
||||||
|
ipSupport = IPTABLESWITHV6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ipSupport
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return UNKNOWN
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||||
|
_, err := client.ListChains("filter")
|
||||||
|
return err == nil
|
||||||
|
}
|
@ -7,6 +7,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -26,20 +28,20 @@ func genKey(format string, input string) string {
|
|||||||
return fmt.Sprintf(format, input)
|
return fmt.Sprintf(format, input)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||||
func NewFirewall(parentCTX context.Context) (firewallManager, error) {
|
func newFirewall(parentCTX context.Context) (firewallManager, error) {
|
||||||
manager, err := newNFTablesManager(parentCTX)
|
checkResult := checkfw.Check()
|
||||||
if err == nil {
|
switch checkResult {
|
||||||
log.Debugf("nftables firewall manager will be used")
|
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||||
return manager, nil
|
log.Debug("creating an iptables firewall manager for route rules")
|
||||||
|
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||||
|
return newIptablesManager(parentCTX, ipv6Supported)
|
||||||
|
case checkfw.NFTABLES:
|
||||||
|
log.Info("creating an nftables firewall manager for route rules")
|
||||||
|
return newNFTablesManager(parentCTX), nil
|
||||||
}
|
}
|
||||||
fMgr, err := newIptablesManager(parentCTX)
|
|
||||||
if err != nil {
|
return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
|
||||||
log.Debugf("failed to initialize iptables for root mgr: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Debugf("iptables firewall manager will be used")
|
|
||||||
return fMgr, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInPair(pair routerPair) routerPair {
|
func getInPair(pair routerPair) routerPair {
|
||||||
|
@ -6,9 +6,10 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall returns a nil manager
|
// newFirewall returns a nil manager
|
||||||
func NewFirewall(context.Context) (firewallManager, error) {
|
func newFirewall(context.Context) (firewallManager, error) {
|
||||||
return nil, fmt.Errorf("firewall not supported on this OS")
|
return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
@ -49,29 +49,28 @@ type iptablesManager struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newIptablesManager(parentCtx context.Context) (*iptablesManager, error) {
|
func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
|
||||||
} else if !isIptablesClientAvailable(ipv4Client) {
|
|
||||||
return nil, fmt.Errorf("iptables is missing for ipv4")
|
|
||||||
}
|
|
||||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to initialize iptables for ipv6: %s", err)
|
|
||||||
} else if !isIptablesClientAvailable(ipv6Client) {
|
|
||||||
log.Infof("iptables is missing for ipv6")
|
|
||||||
ipv6Client = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
return &iptablesManager{
|
manager := &iptablesManager{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
ipv4Client: ipv4Client,
|
ipv4Client: ipv4Client,
|
||||||
ipv6Client: ipv6Client,
|
|
||||||
rules: make(map[string]map[string][]string),
|
rules: make(map[string]map[string][]string),
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
if ipv6Supported {
|
||||||
|
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||||
@ -486,8 +485,3 @@ func getIptablesRuleType(table string) string {
|
|||||||
}
|
}
|
||||||
return ruleType
|
return ruleType
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|
||||||
_, err := client.ListChains("filter")
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, _ := newIptablesManager(context.TODO())
|
manager, err := newIptablesManager(context.TODO(), true)
|
||||||
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err := manager.RestoreOrCreateContainers()
|
err = manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
|
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
|
||||||
|
@ -36,7 +36,7 @@ type DefaultManager struct {
|
|||||||
|
|
||||||
// NewManager returns a new route manager
|
// NewManager returns a new route manager
|
||||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||||
serverRouter, err := newServerRouter(ctx, wgInterface)
|
srvRouter, err := newServerRouter(ctx, wgInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("server router is not supported: %s", err)
|
log.Errorf("server router is not supported: %s", err)
|
||||||
}
|
}
|
||||||
@ -46,7 +46,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
clientNetworks: make(map[string]*clientNetwork),
|
clientNetworks: make(map[string]*clientNetwork),
|
||||||
serverRouter: serverRouter,
|
serverRouter: srvRouter,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
pubKey: pubKey,
|
pubKey: pubKey,
|
||||||
|
@ -3,11 +3,12 @@ package routemanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/pion/transport/v2/stdnet"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v2/stdnet"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@ -30,6 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
inputInitRoutes []*route.Route
|
inputInitRoutes []*route.Route
|
||||||
inputRoutes []*route.Route
|
inputRoutes []*route.Route
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
|
removeSrvRouter bool
|
||||||
serverRoutesExpected int
|
serverRoutesExpected int
|
||||||
clientNetworkWatchersExpected int
|
clientNetworkWatchersExpected int
|
||||||
}{
|
}{
|
||||||
@ -117,6 +119,35 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
serverRoutesExpected: 1,
|
serverRoutesExpected: 1,
|
||||||
clientNetworkWatchersExpected: 1,
|
clientNetworkWatchersExpected: 1,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
|
||||||
|
inputRoutes: []*route.Route{
|
||||||
|
{
|
||||||
|
ID: "a",
|
||||||
|
NetID: "routeA",
|
||||||
|
Peer: localPeerKey,
|
||||||
|
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
Metric: 9999,
|
||||||
|
Masquerade: false,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "b",
|
||||||
|
NetID: "routeB",
|
||||||
|
Peer: remotePeerKey1,
|
||||||
|
Network: netip.MustParsePrefix("8.8.9.9/32"),
|
||||||
|
NetworkType: route.IPv4Network,
|
||||||
|
Metric: 9999,
|
||||||
|
Masquerade: false,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputSerial: 1,
|
||||||
|
removeSrvRouter: true,
|
||||||
|
serverRoutesExpected: 0,
|
||||||
|
clientNetworkWatchersExpected: 1,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Should Create 1 HA Route and 1 Standalone",
|
name: "Should Create 1 HA Route and 1 Standalone",
|
||||||
inputRoutes: []*route.Route{
|
inputRoutes: []*route.Route{
|
||||||
@ -385,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
|
||||||
|
if testCase.removeSrvRouter {
|
||||||
|
routeManager.serverRouter = nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes with init routes")
|
require.NoError(t, err, "should update routes with init routes")
|
||||||
@ -395,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
|
|
||||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||||
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
|
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||||
}
|
}
|
||||||
|
@ -86,10 +86,10 @@ type nftablesManager struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
func newNFTablesManager(parentCtx context.Context) *nftablesManager {
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
mgr := &nftablesManager{
|
return &nftablesManager{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
|||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
defaultForwardRules: make([]*nftables.Rule, 2),
|
defaultForwardRules: make([]*nftables.Rule, 2),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := mgr.isSupported()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = mgr.readFilterTable()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return mgr, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanRoutingRules cleans existing nftables rules from the system
|
// CleanRoutingRules cleans existing nftables rules from the system
|
||||||
@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
|
if table.Name == "filter" {
|
||||||
|
n.filterTable = table
|
||||||
|
continue
|
||||||
|
}
|
||||||
if table.Name == nftablesTable {
|
if table.Name == nftablesTable {
|
||||||
if table.Family == nftables.TableFamilyIPv4 {
|
if table.Family == nftables.TableFamilyIPv4 {
|
||||||
n.tableIPv4 = table
|
n.tableIPv4 = table
|
||||||
@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nftablesManager) readFilterTable() error {
|
|
||||||
tables, err := n.conn.ListTables()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range tables {
|
|
||||||
if t.Name == "filter" {
|
|
||||||
n.filterTable = t
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
||||||
if n.defaultForwardRules[0] == nil {
|
if n.defaultForwardRules[0] == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nftablesManager) isSupported() error {
|
|
||||||
_, err := n.conn.ListChains()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("nftables is not supported: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
// getPayloadDirectives get expression directives based on ip version and direction
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||||
switch {
|
switch {
|
||||||
|
@ -10,20 +10,23 @@ import (
|
|||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||||
|
|
||||||
manager, err := newNFTablesManager(context.TODO())
|
if checkfw.Check() != checkfw.NFTABLES {
|
||||||
if err != nil {
|
t.Skip("nftables not supported on this OS")
|
||||||
t.Fatalf("failed to create nftables manager: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
manager := newNFTablesManager(context.TODO())
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
err := manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||||
@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||||
|
if checkfw.Check() != checkfw.NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this OS")
|
||||||
|
}
|
||||||
|
|
||||||
for _, testCase := range insertRuleTestCases {
|
for _, testCase := range insertRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := newNFTablesManager(context.TODO())
|
manager := newNFTablesManager(context.TODO())
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create nftables manager: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
err := manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||||
@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||||
|
if checkfw.Check() != checkfw.NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this OS")
|
||||||
|
}
|
||||||
|
|
||||||
for _, testCase := range removeRuleTestCases {
|
for _, testCase := range removeRuleTestCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
manager, err := newNFTablesManager(context.TODO())
|
manager := newNFTablesManager(context.TODO())
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create nftables manager: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
defer manager.CleanRoutingRules()
|
defer manager.CleanRoutingRules()
|
||||||
|
|
||||||
err = manager.RestoreOrCreateContainers()
|
err := manager.RestoreOrCreateContainers()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
table := manager.tableIPv4
|
table := manager.tableIPv4
|
||||||
|
@ -22,7 +22,7 @@ type defaultServerRouter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
|
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
|
||||||
firewall, err := NewFirewall(ctx)
|
firewall, err := newFirewall(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user