mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-01 20:43:43 +01:00
93d20e370b
add an incoming firewall rule for each routing pair; the pair for the incoming rule has its source and destination inverted
433 lines
11 KiB
Go
433 lines
11 KiB
Go
package routemanager
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/google/nftables/binaryutil"
|
|
"github.com/google/nftables/expr"
|
|
log "github.com/sirupsen/logrus"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
)
|
|
import "github.com/google/nftables"
|
|
|
|
// nftablesTable is the name of the table owned by the routing manager; one
// table with this name is used per address family (ip and ip6).
const nftablesTable = "netbird-rt"

// nftablesRoutingForwardingChain is the filter chain, hooked on forward, that
// holds the per-pair accept rules.
const nftablesRoutingForwardingChain = "netbird-rt-fwd"

// nftablesRoutingNatChain is the nat chain, hooked on postrouting, that holds
// the per-pair masquerade rules.
const nftablesRoutingNatChain = "netbird-rt-nat"
|
|
|
|
// Address length and byte offsets of the source/destination address fields in
// the IPv4 header; used with expr.PayloadBaseNetworkHeader.
const (
	ipv4Len        = 4
	ipv4SrcOffset  = 12
	ipv4DestOffset = 16
)

// Address length and byte offsets of the source/destination address fields in
// the IPv6 header; used with expr.PayloadBaseNetworkHeader.
const (
	ipv6Len        = 16
	ipv6SrcOffset  = 8
	ipv6DestOffset = 24
)

// Direction selectors understood by getPayloadDirectives.
const (
	exprDirectionSource      = "source"
	exprDirectionDestination = "destination"
)
|
|
|
|
// some presets for building nftable rules
|
|
var (
|
|
zeroXor = binaryutil.NativeEndian.PutUint32(0)
|
|
|
|
zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)
|
|
|
|
exprAllowRelatedEstablished = []expr.Any{
|
|
&expr.Ct{
|
|
Register: 1,
|
|
SourceRegister: false,
|
|
Key: 0,
|
|
},
|
|
&expr.Bitwise{
|
|
DestRegister: 1,
|
|
SourceRegister: 1,
|
|
Len: 4,
|
|
Mask: []uint8{0x6, 0x0, 0x0, 0x0},
|
|
Xor: zeroXor,
|
|
},
|
|
&expr.Cmp{
|
|
Register: 1,
|
|
Data: binaryutil.NativeEndian.PutUint32(0),
|
|
},
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
}
|
|
|
|
exprCounterAccept = []expr.Any{
|
|
&expr.Counter{},
|
|
&expr.Verdict{
|
|
Kind: expr.VerdictAccept,
|
|
},
|
|
}
|
|
)
|
|
|
|
// nftablesManager manages the netbird routing tables, chains and rules through
// an nftables netlink connection. Its exported methods serialize access by
// locking mux.
type nftablesManager struct {
	// ctx and stop are not referenced in this file.
	// NOTE(review): presumably a lifecycle context and its cancel function set
	// by the constructor — confirm.
	ctx  context.Context
	stop context.CancelFunc
	// conn is the nftables connection; table/chain/rule changes are queued on
	// it and applied by conn.Flush.
	conn *nftables.Conn
	// tableIPv4 and tableIPv6 are the netbird-rt tables of the ip and ip6
	// families, populated by RestoreOrCreateContainers.
	tableIPv4 *nftables.Table
	tableIPv6 *nftables.Table
	// chains maps a family key (ipv4 / ipv6) to that family's chains by name.
	chains map[string]map[string]*nftables.Chain
	// rules caches tracked rules keyed by their UserData string (built by genKey
	// for routing rules).
	rules map[string]*nftables.Rule
	// mux is locked by the exported methods for their whole duration.
	mux sync.Mutex
}
|
|
|
|
// CleanRoutingRules cleans existing nftables rules from the system
|
|
func (n *nftablesManager) CleanRoutingRules() {
|
|
n.mux.Lock()
|
|
defer n.mux.Unlock()
|
|
log.Debug("flushing tables")
|
|
if n.tableIPv4 != nil && n.tableIPv6 != nil {
|
|
n.conn.FlushTable(n.tableIPv6)
|
|
n.conn.FlushTable(n.tableIPv4)
|
|
}
|
|
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
|
|
}
|
|
|
|
// RestoreOrCreateContainers restores existing nftables containers (tables and chains)
|
|
// if they don't exist, we create them
|
|
func (n *nftablesManager) RestoreOrCreateContainers() error {
|
|
n.mux.Lock()
|
|
defer n.mux.Unlock()
|
|
|
|
if n.tableIPv6 != nil && n.tableIPv4 != nil {
|
|
log.Debugf("nftables: containers already restored, skipping")
|
|
return nil
|
|
}
|
|
|
|
tables, err := n.conn.ListTables()
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: unable to list tables: %v", err)
|
|
}
|
|
|
|
for _, table := range tables {
|
|
if table.Name == nftablesTable {
|
|
if table.Family == nftables.TableFamilyIPv4 {
|
|
n.tableIPv4 = table
|
|
continue
|
|
}
|
|
n.tableIPv6 = table
|
|
}
|
|
}
|
|
|
|
if n.tableIPv4 == nil {
|
|
n.tableIPv4 = n.conn.AddTable(&nftables.Table{
|
|
Name: nftablesTable,
|
|
Family: nftables.TableFamilyIPv4,
|
|
})
|
|
}
|
|
|
|
if n.tableIPv6 == nil {
|
|
n.tableIPv6 = n.conn.AddTable(&nftables.Table{
|
|
Name: nftablesTable,
|
|
Family: nftables.TableFamilyIPv6,
|
|
})
|
|
}
|
|
|
|
chains, err := n.conn.ListChains()
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: unable to list chains: %v", err)
|
|
}
|
|
|
|
n.chains[ipv4] = make(map[string]*nftables.Chain)
|
|
n.chains[ipv6] = make(map[string]*nftables.Chain)
|
|
|
|
for _, chain := range chains {
|
|
switch {
|
|
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4:
|
|
n.chains[ipv4][chain.Name] = chain
|
|
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6:
|
|
n.chains[ipv6][chain.Name] = chain
|
|
}
|
|
}
|
|
|
|
if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found {
|
|
n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
|
|
Name: nftablesRoutingForwardingChain,
|
|
Table: n.tableIPv4,
|
|
Hooknum: nftables.ChainHookForward,
|
|
Priority: nftables.ChainPriorityNATDest + 1,
|
|
Type: nftables.ChainTypeFilter,
|
|
})
|
|
}
|
|
|
|
if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found {
|
|
n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
|
|
Name: nftablesRoutingNatChain,
|
|
Table: n.tableIPv4,
|
|
Hooknum: nftables.ChainHookPostrouting,
|
|
Priority: nftables.ChainPriorityNATSource - 1,
|
|
Type: nftables.ChainTypeNAT,
|
|
})
|
|
}
|
|
|
|
if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found {
|
|
n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
|
|
Name: nftablesRoutingForwardingChain,
|
|
Table: n.tableIPv6,
|
|
Hooknum: nftables.ChainHookForward,
|
|
Priority: nftables.ChainPriorityNATDest + 1,
|
|
Type: nftables.ChainTypeFilter,
|
|
})
|
|
}
|
|
|
|
if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found {
|
|
n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
|
|
Name: nftablesRoutingNatChain,
|
|
Table: n.tableIPv6,
|
|
Hooknum: nftables.ChainHookPostrouting,
|
|
Priority: nftables.ChainPriorityNATSource - 1,
|
|
Type: nftables.ChainTypeNAT,
|
|
})
|
|
}
|
|
|
|
err = n.refreshRulesMap()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
n.checkOrCreateDefaultForwardingRules()
|
|
err = n.conn.Flush()
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
|
// duplicates and to get missing attributes that we don't have when adding new rules
|
|
func (n *nftablesManager) refreshRulesMap() error {
|
|
for _, registeredChains := range n.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := n.conn.GetRules(chain.Table, chain)
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
|
}
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 {
|
|
n.rules[string(rule.UserData)] = rule
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
|
|
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
|
|
_, foundIPv4 := n.rules[ipv4Forwarding]
|
|
if !foundIPv4 {
|
|
n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{
|
|
Table: n.tableIPv4,
|
|
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
|
|
Exprs: exprAllowRelatedEstablished,
|
|
UserData: []byte(ipv4Forwarding),
|
|
})
|
|
}
|
|
|
|
_, foundIPv6 := n.rules[ipv6Forwarding]
|
|
if !foundIPv6 {
|
|
n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{
|
|
Table: n.tableIPv6,
|
|
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
|
|
Exprs: exprAllowRelatedEstablished,
|
|
UserData: []byte(ipv6Forwarding),
|
|
})
|
|
}
|
|
}
|
|
|
|
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
|
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|
n.mux.Lock()
|
|
defer n.mux.Unlock()
|
|
|
|
err := n.refreshRulesMap()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if pair.masquerade {
|
|
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = n.conn.Flush()
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
|
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
|
|
|
|
prefix := netip.MustParsePrefix(pair.source)
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
|
|
|
var expression []expr.Any
|
|
if isNat {
|
|
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
} else {
|
|
expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
}
|
|
|
|
ruleKey := genKey(format, pair.ID)
|
|
|
|
_, exists := n.rules[ruleKey]
|
|
if exists {
|
|
err := n.removeRoutingRule(format, pair)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if prefix.Addr().Unmap().Is4() {
|
|
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
|
Table: n.tableIPv4,
|
|
Chain: n.chains[ipv4][chain],
|
|
Exprs: expression,
|
|
UserData: []byte(ruleKey),
|
|
})
|
|
} else {
|
|
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
|
Table: n.tableIPv6,
|
|
Chain: n.chains[ipv6][chain],
|
|
Exprs: expression,
|
|
UserData: []byte(ruleKey),
|
|
})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
|
|
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|
n.mux.Lock()
|
|
defer n.mux.Unlock()
|
|
|
|
err := n.refreshRulesMap()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = n.removeRoutingRule(forwardingFormat, pair)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = n.removeRoutingRule(natFormat, pair)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = n.conn.Flush()
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
|
}
|
|
log.Debugf("nftables: removed rules for %s", pair.destination)
|
|
return nil
|
|
}
|
|
|
|
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
|
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
|
|
ruleKey := genKey(format, pair.ID)
|
|
|
|
rule, found := n.rules[ruleKey]
|
|
if found {
|
|
ruleType := "forwarding"
|
|
if rule.Chain.Type == nftables.ChainTypeNAT {
|
|
ruleType = "nat"
|
|
}
|
|
|
|
err := n.conn.DelRule(rule)
|
|
if err != nil {
|
|
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
|
|
}
|
|
|
|
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
|
|
|
|
delete(n.rules, ruleKey)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getPayloadDirectives get expression directives based on ip version and direction
|
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
|
switch {
|
|
case direction == exprDirectionSource && isIPv4:
|
|
return ipv4SrcOffset, ipv4Len, zeroXor
|
|
case direction == exprDirectionDestination && isIPv4:
|
|
return ipv4DestOffset, ipv4Len, zeroXor
|
|
case direction == exprDirectionSource && isIPv6:
|
|
return ipv6SrcOffset, ipv6Len, zeroXor6
|
|
case direction == exprDirectionDestination && isIPv6:
|
|
return ipv6DestOffset, ipv6Len, zeroXor6
|
|
default:
|
|
panic("no matched payload directive")
|
|
}
|
|
}
|
|
|
|
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
|
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
|
|
ip, network, _ := net.ParseCIDR(cidr)
|
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
add := ipToAdd.Unmap()
|
|
|
|
offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6())
|
|
|
|
return []expr.Any{
|
|
// fetch src add
|
|
&expr.Payload{
|
|
DestRegister: 1,
|
|
Base: expr.PayloadBaseNetworkHeader,
|
|
Offset: offSet,
|
|
Len: packetLen,
|
|
},
|
|
// net mask
|
|
&expr.Bitwise{
|
|
DestRegister: 1,
|
|
SourceRegister: 1,
|
|
Len: packetLen,
|
|
Mask: network.Mask,
|
|
Xor: zeroXor,
|
|
},
|
|
// net address
|
|
&expr.Cmp{
|
|
Register: 1,
|
|
Data: add.AsSlice(),
|
|
},
|
|
}
|
|
}
|