package nftables import ( "bytes" "encoding/binary" "fmt" "net" "net/netip" "strings" "sync" "github.com/google/nftables" "github.com/google/nftables/expr" "github.com/google/uuid" "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/iface" ) const ( // FilterTableName is the name of the table that is used for filtering by the Netbird client FilterTableName = "netbird-acl" // FilterInputChainName is the name of the chain that is used for filtering incoming packets FilterInputChainName = "netbird-acl-input-filter" // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets FilterOutputChainName = "netbird-acl-output-filter" ) // Manager of iptables firewall type Manager struct { mutex sync.Mutex conn *nftables.Conn tableIPv4 *nftables.Table tableIPv6 *nftables.Table filterInputChainIPv4 *nftables.Chain filterOutputChainIPv4 *nftables.Chain filterInputChainIPv6 *nftables.Chain filterOutputChainIPv6 *nftables.Chain wgIface iFaceMapper } // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string Address() iface.WGAddress } // Create nftables firewall manager func Create(wgIface iFaceMapper) (*Manager, error) { m := &Manager{ conn: &nftables.Conn{}, wgIface: wgIface, } if err := m.Reset(); err != nil { return nil, err } return m, nil } // AddFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule func (m *Manager) AddFiltering( ip net.IP, proto fw.Protocol, sPort *fw.Port, dPort *fw.Port, direction fw.RuleDirection, action fw.Action, comment string, ) (fw.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() var ( err error table *nftables.Table chain *nftables.Chain ) if direction == fw.RuleDirectionOUT { table, chain, err = m.chain( ip, FilterOutputChainName, nftables.ChainHookOutput, nftables.ChainPriorityFilter, nftables.ChainTypeFilter) } else { table, chain, err = m.chain( ip, FilterInputChainName, nftables.ChainHookInput, nftables.ChainPriorityFilter, nftables.ChainTypeFilter) } if err != nil { return nil, err } ifaceKey := expr.MetaKeyIIFNAME if direction == fw.RuleDirectionOUT { ifaceKey = expr.MetaKeyOIFNAME } expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: ifname(m.wgIface.Name()), }, } if proto != "all" { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: uint32(9), Len: uint32(1), }) var protoData []byte switch proto { case fw.ProtocolTCP: protoData = []byte{unix.IPPROTO_TCP} case fw.ProtocolUDP: protoData = []byte{unix.IPPROTO_UDP} case fw.ProtocolICMP: protoData = []byte{unix.IPPROTO_ICMP} default: return nil, fmt.Errorf("unsupported protocol: %s", proto) } expressions = append(expressions, &expr.Cmp{ Register: 1, Op: expr.CmpOpEq, Data: protoData, }) } // don't use IP matching if IP is ip 0.0.0.0 if s := ip.String(); s != "0.0.0.0" && s != "::" { // source address position var adrLen, adrOffset uint32 if ip.To4() == nil { adrLen = 16 adrOffset = 8 } else { adrLen = 4 adrOffset = 12 } // change to destination address position if need if direction == fw.RuleDirectionOUT { adrOffset += adrLen } ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: adrOffset, Len: adrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: add.AsSlice(), }, ) } if sPort != nil && len(sPort.Values) != 0 { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: encodePort(*sPort), }, ) } if dPort != nil && len(dPort.Values) != 0 { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: encodePort(*dPort), }, ) } if action == fw.ActionAccept { expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) } else { expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) } id := uuid.New().String() userData := []byte(strings.Join([]string{id, comment}, " ")) _ = m.conn.InsertRule(&nftables.Rule{ Table: table, Chain: chain, Position: 0, Exprs: expressions, UserData: userData, }) if err := m.conn.Flush(); err != nil { return nil, err } list, err := m.conn.GetRules(table, chain) if err != nil { return nil, err } // Add the rule to the chain rule := &Rule{id: id} for _, r := range list { if bytes.Equal(r.UserData, userData) { rule.Rule = r break } } if rule.Rule == nil { return nil, fmt.Errorf("rule not found") } return rule, nil } // chain returns the chain for the given IP address with specific settings func (m *Manager) chain( ip net.IP, name string, hook nftables.ChainHook, priority nftables.ChainPriority, cType nftables.ChainType, ) (*nftables.Table, *nftables.Chain, error) { var err error getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) { if c != nil { return c, nil } return m.createChainIfNotExists(tf, name, hook, priority, cType) } if ip.To4() != nil { if name == FilterInputChainName { m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4) return m.tableIPv4, m.filterInputChainIPv4, err } m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4) return m.tableIPv4, m.filterOutputChainIPv4, err } if name == FilterInputChainName { m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6) return m.tableIPv4, m.filterInputChainIPv6, err } m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6) return m.tableIPv4, m.filterOutputChainIPv6, err } // table returns the table for the given family of the IP address func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) { if family == nftables.TableFamilyIPv4 { if m.tableIPv4 != nil { return m.tableIPv4, nil } table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4) if err != nil { return nil, err } m.tableIPv4 = table return m.tableIPv4, nil } if m.tableIPv6 != nil { return m.tableIPv6, nil } table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6) if err != nil { return nil, err } m.tableIPv6 = table return m.tableIPv6, nil } func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) { tables, err := m.conn.ListTablesOfFamily(family) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } for _, t := range tables { if t.Name == FilterTableName { return t, nil } } return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil } func (m *Manager) createChainIfNotExists( family nftables.TableFamily, name string, hooknum nftables.ChainHook, priority nftables.ChainPriority, chainType nftables.ChainType, ) (*nftables.Chain, error) { table, err := m.table(family) if err != nil { return nil, err } chains, err := m.conn.ListChainsOfTableFamily(family) if err != nil { return nil, fmt.Errorf("list of chains: %w", err) } for _, c := range chains { if c.Name == name && c.Table.Name == table.Name { return c, nil } } polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, Table: table, Hooknum: hooknum, Priority: priority, Type: chainType, Policy: &polAccept, } chain = m.conn.AddChain(chain) ifaceKey := expr.MetaKeyIIFNAME shiftDSTAddr := 0 if name == FilterOutputChainName { ifaceKey = expr.MetaKeyOIFNAME shiftDSTAddr = 1 } expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: ifname(m.wgIface.Name()), }, } mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask) if m.wgIface.Address().IP.To4() == nil { ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16()) expressions = append(expressions, &expr.Payload{ DestRegister: 2, Base: expr.PayloadBaseNetworkHeader, Offset: uint32(8 + (16 * shiftDSTAddr)), Len: 16, }, &expr.Bitwise{ SourceRegister: 2, DestRegister: 2, Len: 16, Xor: []byte{0x0, 0x0, 0x0, 0x0}, Mask: mask.Unmap().AsSlice(), }, &expr.Cmp{ Op: expr.CmpOpNeq, Register: 2, Data: ip.Unmap().AsSlice(), }, &expr.Verdict{Kind: expr.VerdictAccept}, ) } else { ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) expressions = append(expressions, &expr.Payload{ DestRegister: 2, Base: expr.PayloadBaseNetworkHeader, Offset: uint32(12 + (4 * shiftDSTAddr)), Len: 4, }, &expr.Bitwise{ SourceRegister: 2, DestRegister: 2, Len: 4, Xor: []byte{0x0, 0x0, 0x0, 0x0}, Mask: m.wgIface.Address().Network.Mask, }, &expr.Cmp{ Op: expr.CmpOpNeq, Register: 2, Data: ip.Unmap().AsSlice(), }, &expr.Verdict{Kind: expr.VerdictAccept}, ) } _ = m.conn.AddRule(&nftables.Rule{ Table: table, Chain: chain, Exprs: expressions, }) expressions = []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: ifname(m.wgIface.Name()), }, &expr.Verdict{Kind: expr.VerdictDrop}, } _ = m.conn.AddRule(&nftables.Rule{ Table: table, Chain: chain, Exprs: expressions, }) if err := m.conn.Flush(); err != nil { return nil, err } return chain, nil } // DeleteRule from the firewall by rule definition func (m *Manager) DeleteRule(rule fw.Rule) error { nativeRule, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") } if err := m.conn.DelRule(nativeRule.Rule); err != nil { return err } return m.conn.Flush() } // Reset firewall to the default state func (m *Manager) Reset() error { m.mutex.Lock() defer m.mutex.Unlock() chains, err := m.conn.ListChains() if err != nil { return fmt.Errorf("list of chains: %w", err) } for _, c := range chains { if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { m.conn.DelChain(c) } } tables, err := m.conn.ListTables() if err != nil { return fmt.Errorf("list of tables: %w", err) } for _, t := range tables { if t.Name == FilterTableName { m.conn.DelTable(t) } } return m.conn.Flush() } func encodePort(port fw.Port) []byte { bs := make([]byte, 2) binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) return bs } func ifname(n string) []byte { b := make([]byte, 16) copy(b, []byte(n+"\x00")) return b }