mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 01:38:41 +02:00
Add incoming routing rules (#486)
add an income firewall rule for each routing pair the pair for the income rule has inverted source and destination
This commit is contained in:
parent
878ca6db22
commit
93d20e370b
@ -15,6 +15,8 @@ const (
|
|||||||
ipv4Nat = "netbird-rt-ipv4-nat"
|
ipv4Nat = "netbird-rt-ipv4-nat"
|
||||||
natFormat = "netbird-nat-%s"
|
natFormat = "netbird-nat-%s"
|
||||||
forwardingFormat = "netbird-fwd-%s"
|
forwardingFormat = "netbird-fwd-%s"
|
||||||
|
inNatFormat = "netbird-nat-in-%s"
|
||||||
|
inForwardingFormat = "netbird-fwd-in-%s"
|
||||||
ipv6 = "ipv6"
|
ipv6 = "ipv6"
|
||||||
ipv4 = "ipv4"
|
ipv4 = "ipv4"
|
||||||
)
|
)
|
||||||
@ -53,3 +55,13 @@ func NewFirewall(parentCTX context.Context) firewallManager {
|
|||||||
|
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getInPair(pair routerPair) routerPair {
|
||||||
|
return routerPair{
|
||||||
|
ID: pair.ID,
|
||||||
|
// invert source/destination
|
||||||
|
source: pair.destination,
|
||||||
|
destination: pair.source,
|
||||||
|
masquerade: pair.masquerade,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -311,7 +311,37 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
defer i.mux.Unlock()
|
defer i.mux.Unlock()
|
||||||
|
|
||||||
|
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pair.masquerade {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoutingRule inserts an iptable rule
|
||||||
|
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
prefix := netip.MustParsePrefix(pair.source)
|
||||||
ipVersion := ipv4
|
ipVersion := ipv4
|
||||||
iptablesClient := i.ipv4Client
|
iptablesClient := i.ipv4Client
|
||||||
@ -320,43 +350,22 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
ipVersion = ipv6
|
ipVersion = ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, pair.ID)
|
ruleKey := genKey(keyFormat, pair.ID)
|
||||||
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination)
|
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
|
||||||
existingRule, found := i.rules[ipVersion][forwardRuleKey]
|
existingRule, found := i.rules[ipVersion][ruleKey]
|
||||||
if found {
|
if found {
|
||||||
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
|
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
|
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
||||||
}
|
}
|
||||||
delete(i.rules[ipVersion], forwardRuleKey)
|
delete(i.rules[ipVersion], ruleKey)
|
||||||
}
|
}
|
||||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
err = iptablesClient.Insert(table, chain, 1, rule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err)
|
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
i.rules[ipVersion][forwardRuleKey] = forwardRule
|
i.rules[ipVersion][ruleKey] = rule
|
||||||
|
|
||||||
if !pair.masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, pair.ID)
|
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
|
|
||||||
existingRule, found = i.rules[ipVersion][natRuleKey]
|
|
||||||
if found {
|
|
||||||
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
delete(i.rules[ipVersion], natRuleKey)
|
|
||||||
}
|
|
||||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
i.rules[ipVersion][natRuleKey] = natRule
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -366,7 +375,37 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
i.mux.Lock()
|
i.mux.Lock()
|
||||||
defer i.mux.Unlock()
|
defer i.mux.Unlock()
|
||||||
|
|
||||||
|
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pair.masquerade {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutingRule removes an iptables rule
|
||||||
|
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
prefix := netip.MustParsePrefix(pair.source)
|
||||||
ipVersion := ipv4
|
ipVersion := ipv4
|
||||||
iptablesClient := i.ipv4Client
|
iptablesClient := i.ipv4Client
|
||||||
@ -375,29 +414,23 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
ipVersion = ipv6
|
ipVersion = ipv6
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardRuleKey := genKey(forwardingFormat, pair.ID)
|
ruleKey := genKey(keyFormat, pair.ID)
|
||||||
existingRule, found := i.rules[ipVersion][forwardRuleKey]
|
existingRule, found := i.rules[ipVersion][ruleKey]
|
||||||
if found {
|
if found {
|
||||||
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
|
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
|
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(i.rules[ipVersion], forwardRuleKey)
|
delete(i.rules[ipVersion], ruleKey)
|
||||||
|
|
||||||
if !pair.masquerade {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, pair.ID)
|
|
||||||
existingRule, found = i.rules[ipVersion][natRuleKey]
|
|
||||||
if found {
|
|
||||||
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(i.rules[ipVersion], natRuleKey)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getIptablesRuleType(table string) string {
|
||||||
|
ruleType := "forwarding"
|
||||||
|
if table == iptablesNatTable {
|
||||||
|
ruleType = "nat"
|
||||||
|
}
|
||||||
|
return ruleType
|
||||||
|
}
|
||||||
|
@ -159,6 +159,17 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
require.True(t, found, "forwarding rule should exist in the manager map")
|
require.True(t, found, "forwarding rule should exist in the manager map")
|
||||||
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
|
||||||
|
|
||||||
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
|
require.True(t, exists, "income forwarding rule should exist")
|
||||||
|
|
||||||
|
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
|
||||||
|
require.True(t, found, "income forwarding rule should exist in the manager map")
|
||||||
|
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||||
|
|
||||||
@ -172,7 +183,23 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
require.False(t, exists, "nat rule should not be created")
|
require.False(t, exists, "nat rule should not be created")
|
||||||
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
|
||||||
require.False(t, foundNat, "nat rule should exist in the map")
|
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||||
|
}
|
||||||
|
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
if testCase.inputPair.masquerade {
|
||||||
|
require.True(t, exists, "income nat rule should be created")
|
||||||
|
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
|
||||||
|
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||||
|
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||||
|
} else {
|
||||||
|
require.False(t, exists, "nat rule should not be created")
|
||||||
|
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
|
||||||
|
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -213,12 +240,24 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
||||||
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
|
||||||
|
|
||||||
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
|
||||||
require.NoError(t, err, "inserting rule should not return error")
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
|
||||||
|
require.NoError(t, err, "inserting rule should not return error")
|
||||||
|
|
||||||
delete(manager.rules, ipv4)
|
delete(manager.rules, ipv4)
|
||||||
delete(manager.rules, ipv6)
|
delete(manager.rules, ipv6)
|
||||||
|
|
||||||
@ -235,12 +274,26 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
require.False(t, found, "forwarding rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||||
|
require.False(t, exists, "income forwarding rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
|
||||||
|
require.False(t, found, "income forwarding rule should exist in the manager map")
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
require.False(t, exists, "nat rule should not exist")
|
require.False(t, exists, "nat rule should not exist")
|
||||||
|
|
||||||
_, found = manager.rules[testCase.ipVersion][natRuleKey]
|
_, found = manager.rules[testCase.ipVersion][natRuleKey]
|
||||||
require.False(t, found, "forwarding rule should exist in the manager map")
|
require.False(t, found, "nat rule should exist in the manager map")
|
||||||
|
|
||||||
|
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
|
||||||
|
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
|
||||||
|
require.False(t, exists, "income nat rule should not exist")
|
||||||
|
|
||||||
|
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
|
||||||
|
require.False(t, found, "income nat rule should exist in the manager map")
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
)
|
)
|
||||||
import "github.com/google/nftables"
|
import "github.com/google/nftables"
|
||||||
|
|
||||||
//
|
|
||||||
const (
|
const (
|
||||||
nftablesTable = "netbird-rt"
|
nftablesTable = "netbird-rt"
|
||||||
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
nftablesRoutingForwardingChain = "netbird-rt-fwd"
|
||||||
@ -248,53 +247,77 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
|
|||||||
n.mux.Lock()
|
n.mux.Lock()
|
||||||
defer n.mux.Unlock()
|
defer n.mux.Unlock()
|
||||||
|
|
||||||
|
err := n.refreshRulesMap()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.masquerade {
|
||||||
|
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = n.conn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||||
|
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
|
||||||
|
|
||||||
prefix := netip.MustParsePrefix(pair.source)
|
prefix := netip.MustParsePrefix(pair.source)
|
||||||
|
|
||||||
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
||||||
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
||||||
|
|
||||||
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
var expression []expr.Any
|
||||||
fwdKey := genKey(forwardingFormat, pair.ID)
|
if isNat {
|
||||||
if prefix.Addr().Unmap().Is4() {
|
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||||
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv4,
|
|
||||||
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(fwdKey),
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
|
expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||||
Table: n.tableIPv6,
|
|
||||||
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
|
|
||||||
Exprs: forwardExp,
|
|
||||||
UserData: []byte(fwdKey),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if pair.masquerade {
|
ruleKey := genKey(format, pair.ID)
|
||||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
||||||
natKey := genKey(natFormat, pair.ID)
|
|
||||||
|
|
||||||
if prefix.Addr().Unmap().Is4() {
|
_, exists := n.rules[ruleKey]
|
||||||
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
|
if exists {
|
||||||
Table: n.tableIPv4,
|
err := n.removeRoutingRule(format, pair)
|
||||||
Chain: n.chains[ipv4][nftablesRoutingNatChain],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natKey),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: n.tableIPv6,
|
|
||||||
Chain: n.chains[ipv6][nftablesRoutingNatChain],
|
|
||||||
Exprs: natExp,
|
|
||||||
UserData: []byte(natKey),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := n.conn.Flush()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Unmap().Is4() {
|
||||||
|
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: n.tableIPv4,
|
||||||
|
Chain: n.chains[ipv4][chain],
|
||||||
|
Exprs: expression,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
|
||||||
|
Table: n.tableIPv6,
|
||||||
|
Chain: n.chains[ipv6][chain],
|
||||||
|
Exprs: expression,
|
||||||
|
UserData: []byte(ruleKey),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -309,26 +332,26 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fwdKey := genKey(forwardingFormat, pair.ID)
|
err = n.removeRoutingRule(forwardingFormat, pair)
|
||||||
natKey := genKey(natFormat, pair.ID)
|
|
||||||
fwdRule, found := n.rules[fwdKey]
|
|
||||||
if found {
|
|
||||||
err = n.conn.DelRule(fwdRule)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
|
return err
|
||||||
}
|
}
|
||||||
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
|
|
||||||
delete(n.rules, fwdKey)
|
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
|
||||||
}
|
|
||||||
natRule, found := n.rules[natKey]
|
|
||||||
if found {
|
|
||||||
err = n.conn.DelRule(natRule)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
|
return err
|
||||||
}
|
}
|
||||||
log.Debugf("nftables: removing nat rule for %s", pair.destination)
|
|
||||||
delete(n.rules, natKey)
|
err = n.removeRoutingRule(natFormat, pair)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = n.conn.Flush()
|
err = n.conn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
|
||||||
@ -337,6 +360,29 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
|
||||||
|
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
|
||||||
|
ruleKey := genKey(format, pair.ID)
|
||||||
|
|
||||||
|
rule, found := n.rules[ruleKey]
|
||||||
|
if found {
|
||||||
|
ruleType := "forwarding"
|
||||||
|
if rule.Chain.Type == nftables.ChainTypeNAT {
|
||||||
|
ruleType = "nat"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := n.conn.DelRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
|
||||||
|
|
||||||
|
delete(n.rules, ruleKey)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// getPayloadDirectives get expression directives based on ip version and direction
|
// getPayloadDirectives get expression directives based on ip version and direction
|
||||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||||
switch {
|
switch {
|
||||||
|
@ -189,6 +189,45 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
||||||
|
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
||||||
|
testingExpression = append(sourceExp, destExp...)
|
||||||
|
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
|
||||||
|
found = 0
|
||||||
|
for _, registeredChains := range manager.chains {
|
||||||
|
for _, chain := range registeredChains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
|
||||||
|
if testCase.inputPair.masquerade {
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
found := 0
|
||||||
|
for _, registeredChains := range manager.chains {
|
||||||
|
for _, chain := range registeredChains {
|
||||||
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||||
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||||
|
found = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -241,6 +280,28 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
UserData: []byte(natRuleKey),
|
UserData: []byte(natRuleKey),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
||||||
|
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
||||||
|
|
||||||
|
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
|
||||||
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
||||||
|
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
||||||
|
Exprs: forwardExp,
|
||||||
|
UserData: []byte(inForwardRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
||||||
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
||||||
|
|
||||||
|
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||||
|
Table: table,
|
||||||
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
||||||
|
Exprs: natExp,
|
||||||
|
UserData: []byte(inNatRuleKey),
|
||||||
|
})
|
||||||
|
|
||||||
err = nftablesTestingClient.Flush()
|
err = nftablesTestingClient.Flush()
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
@ -259,8 +320,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
|
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
||||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
|
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||||
|
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
||||||
|
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user