mirror of
synced 2025-02-02 19:39:17 +01:00
Modify rules in iptables and nftables to accept all traffic not from netbird network but routed through it.
512 lines
11 KiB
512 lines
11 KiB
package nftables
import (
fw "github.com/netbirdio/netbird/client/firewall"
// Names of the nftables table and chains that this package looks up and creates.
const (
// FilterTableName is the name of the table that is used for filtering by the Netbird client
FilterTableName = "netbird-acl"
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
FilterInputChainName = "netbird-acl-input-filter"
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
// Manager of nftables firewall
type Manager struct {
// mutex is released with a deferred Unlock in AddFiltering and Reset
mutex sync.Mutex
// conn is the nftables connection every table, chain and rule operation goes through
conn *nftables.Conn
// tableIPv4 and tableIPv6 cache the per-family filter tables resolved by table()
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
// cached input/output filter chains per address family, filled in by chain()
filterInputChainIPv4 *nftables.Chain
filterOutputChainIPv4 *nftables.Chain
filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain
// wgIface supplies the interface name and address used when building rules
wgIface iFaceMapper
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
// Name is the interface name that generated rules compare the packet's
// input/output interface against
Name() string
// Address is the interface address; its IP and Network (IP and Mask) are read
// when a chain's default rules are built
Address() iface.WGAddress
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
conn: &nftables.Conn{},
wgIface: wgIface,
if err := m.Reset(); err != nil {
return nil, err
return m, nil
// AddFiltering rule to the firewall
//
// The rule matches packets on the interface named by m.wgIface.Name()
// (input interface, or output interface for fw.RuleDirectionOUT), optionally
// narrowed by protocol, IP address and source/destination port, and ends in
// an accept verdict for fw.ActionAccept or a drop verdict otherwise. It is
// inserted at position 0 of the selected chain and flushed; the chain is then
// listed again and the entry carrying the same user data is returned wrapped
// in a Rule. An error is returned when the chain lookup, the flush or the
// listing fails, when proto is unsupported, or when the rule cannot be found
// again after the flush.
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
comment string,
) (fw.Rule, error) {
// NOTE(review): only the deferred Unlock is present; confirm the matching
// m.mutex.Lock() is taken before this point.
defer m.mutex.Unlock()
var (
err error
table *nftables.Table
chain *nftables.Chain
// resolve the table and chain for this direction
// NOTE(review): the argument lists of both m.chain( calls are missing here.
if direction == fw.RuleDirectionOUT {
table, chain, err = m.chain(
} else {
table, chain, err = m.chain(
if err != nil {
return nil, err
// match on the input interface name, or on the output interface name for outgoing rules
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
// every rule starts by comparing the interface name with the manager's interface
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
// load one byte at offset 9 of the network header and compare it with the
// protocol number; offset 9 is the protocol field of an IPv4 header.
// NOTE(review): the same offset is used when ip is IPv6 — confirm this is intended.
if proto != "all" {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
var protoData []byte
switch proto {
case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
return nil, fmt.Errorf("unsupported protocol: %s", proto)
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
// skip IP matching when ip is the IPv6 unspecified address "::"
// NOTE(review): net.IP.String never returns "" (a zero-length IP prints as
// "<nil>"), so the first comparison is always true — presumably the IPv4
// unspecified address was meant; confirm.
if s := ip.String(); s != "" && s != "::" {
// source address position
var adrLen, adrOffset uint32
if ip.To4() == nil {
adrLen = 16
adrOffset = 8
} else {
adrLen = 4
adrOffset = 12
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
adrOffset += adrLen
// the conversion result flag is discarded; Unmap turns an IPv4-mapped IPv6
// address into its 4-byte form so the compared data has adrLen bytes for IPv4
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expressions = append(expressions,
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: adrOffset,
Len: adrLen,
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
// source port: first two bytes of the transport header; only the first
// entry of sPort.Values is encoded by encodePort
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
// destination port: two bytes at offset 2 of the transport header; only the
// first entry of dPort.Values is encoded by encodePort
if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
// any action other than fw.ActionAccept results in a drop verdict
if action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
// the user data "<id> <comment>" is what identifies this rule when the chain
// is listed again below
id := uuid.New().String()
userData := []byte(strings.Join([]string{id, comment}, " "))
_ = m.conn.InsertRule(&nftables.Rule{
Table: table,
Chain: chain,
Position: 0,
Exprs: expressions,
UserData: userData,
if err := m.conn.Flush(); err != nil {
return nil, err
list, err := m.conn.GetRules(table, chain)
if err != nil {
return nil, err
// Find the rule that was just inserted by comparing user data
rule := &Rule{id: id}
for _, r := range list {
if bytes.Equal(r.UserData, userData) {
rule.Rule = r
if rule.Rule == nil {
return nil, fmt.Errorf("rule not found")
return rule, nil
// chain returns the chain for the given IP address with specific settings
func (m *Manager) chain(
ip net.IP,
name string,
hook nftables.ChainHook,
priority nftables.ChainPriority,
cType nftables.ChainType,
) (*nftables.Table, *nftables.Chain, error) {
var err error
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
if c != nil {
return c, nil
return m.createChainIfNotExists(tf, name, hook, priority, cType)
if ip.To4() != nil {
if name == FilterInputChainName {
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterInputChainIPv4, err
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterOutputChainIPv4, err
if name == FilterInputChainName {
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterInputChainIPv6, err
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterOutputChainIPv6, err
// table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
m.tableIPv4 = table
return m.tableIPv4, nil
if m.tableIPv6 != nil {
return m.tableIPv6, nil
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
m.tableIPv6 = table
return m.tableIPv6, nil
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.conn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
for _, t := range tables {
if t.Name == FilterTableName {
return t, nil
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family)
if err != nil {
return nil, err
chains, err := m.conn.ListChainsOfTableFamily(family)
if err != nil {
return nil, fmt.Errorf("list of chains: %w", err)
for _, c := range chains {
if c.Name == name && c.Table.Name == table.Name {
return c, nil
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: table,
Hooknum: hooknum,
Priority: priority,
Type: chainType,
Policy: &polAccept,
chain = m.conn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0
if name == FilterOutputChainName {
ifaceKey = expr.MetaKeyOIFNAME
shiftDSTAddr = 1
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
if m.wgIface.Address().IP.To4() == nil {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
expressions = append(expressions,
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(8 + (16 * shiftDSTAddr)),
Len: 16,
SourceRegister: 2,
DestRegister: 2,
Len: 16,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: mask.Unmap().AsSlice(),
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
&expr.Verdict{Kind: expr.VerdictAccept},
} else {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions = append(expressions,
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12 + (4 * shiftDSTAddr)),
Len: 4,
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
&expr.Verdict{Kind: expr.VerdictAccept},
_ = m.conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
expressions = []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
&expr.Verdict{Kind: expr.VerdictDrop},
_ = m.conn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
if err := m.conn.Flush(); err != nil {
return nil, err
return chain, nil
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error {
nativeRule, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
return err
return m.conn.Flush()
// Reset firewall to the default state
//
// Lists every chain and every table on the connection, picks out the ones
// carrying this package's filter chain and table names, and finishes with a
// Flush of the connection. Listing errors are wrapped and returned; the Flush
// result is returned as-is.
func (m *Manager) Reset() error {
// NOTE(review): only the deferred Unlock is present; confirm the matching
// m.mutex.Lock() is taken before this point.
defer m.mutex.Unlock()
chains, err := m.conn.ListChains()
if err != nil {
return fmt.Errorf("list of chains: %w", err)
for _, c := range chains {
// NOTE(review): the action taken for a matching chain is missing here —
// presumably the chain is deleted; confirm.
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
tables, err := m.conn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
for _, t := range tables {
// NOTE(review): the action taken for a matching table is missing here —
// presumably the table is deleted; confirm.
if t.Name == FilterTableName {
return m.conn.Flush()
func encodePort(port fw.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
// ifname returns n as a fixed 16-byte slice: the bytes of n followed by zero
// padding. A name of 16 bytes or more is cut to its first 16 bytes and then
// carries no trailing zero byte.
func ifname(n string) []byte {
	const size = 16

	// The buffer starts zeroed, so copying just the name leaves the same
	// NUL padding as copying the name plus an explicit terminator.
	padded := make([]byte, size)
	copy(padded, n)
	return padded
}