mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-12 08:58:44 +01:00
93d20e370b
add an income firewall rule for each routing pair; the pair for the income rule has inverted source and destination
334 lines
12 KiB
Go
334 lines
12 KiB
Go
package routemanager
|
|
|
|
import (
|
|
"context"
|
|
"github.com/google/nftables"
|
|
"github.com/google/nftables/expr"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"testing"
|
|
)
|
|
|
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
manager := &nftablesManager{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
conn: &nftables.Conn{},
|
|
chains: make(map[string]map[string]*nftables.Chain),
|
|
rules: make(map[string]*nftables.Rule),
|
|
}
|
|
|
|
nftablesTestingClient := &nftables.Conn{}
|
|
|
|
defer manager.CleanRoutingRules()
|
|
|
|
err := manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
|
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
|
|
|
|
pair := routerPair{
|
|
ID: "abc",
|
|
source: "100.100.100.1/32",
|
|
destination: "100.100.100.0/24",
|
|
masquerade: true,
|
|
}
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
|
|
|
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
|
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv4,
|
|
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
|
|
Exprs: forward4Exp,
|
|
UserData: []byte(forward4RuleKey),
|
|
})
|
|
|
|
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
nat4RuleKey := genKey(natFormat, pair.ID)
|
|
|
|
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv4,
|
|
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
|
|
Exprs: nat4Exp,
|
|
UserData: []byte(nat4RuleKey),
|
|
})
|
|
|
|
err = nftablesTestingClient.Flush()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
pair = routerPair{
|
|
ID: "xyz",
|
|
source: "fc00::1/128",
|
|
destination: "fc11::/64",
|
|
masquerade: true,
|
|
}
|
|
|
|
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
|
|
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
|
|
|
|
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
|
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv6,
|
|
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
|
|
Exprs: forward6Exp,
|
|
UserData: []byte(forward6RuleKey),
|
|
})
|
|
|
|
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
nat6RuleKey := genKey(natFormat, pair.ID)
|
|
|
|
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv6,
|
|
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
|
|
Exprs: nat6Exp,
|
|
UserData: []byte(nat6RuleKey),
|
|
})
|
|
|
|
err = nftablesTestingClient.Flush()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
manager.tableIPv4 = nil
|
|
manager.tableIPv6 = nil
|
|
|
|
err = manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
|
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
|
|
|
|
foundRule, found := manager.rules[forward4RuleKey]
|
|
require.True(t, found, "forwarding rule should exist in the map")
|
|
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
|
|
|
|
foundRule, found = manager.rules[nat4RuleKey]
|
|
require.True(t, found, "nat rule should exist in the map")
|
|
// match len of output as nftables client doesn't return expressions with masquerade expression
|
|
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
|
|
|
|
foundRule, found = manager.rules[forward6RuleKey]
|
|
require.True(t, found, "forwarding rule should exist in the map")
|
|
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
|
|
|
|
foundRule, found = manager.rules[nat6RuleKey]
|
|
require.True(t, found, "nat rule should exist in the map")
|
|
// match len of output as nftables client doesn't return expressions with masquerade expression
|
|
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
|
|
}
|
|
|
|
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|
|
|
for _, testCase := range insertRuleTestCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
manager := &nftablesManager{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
conn: &nftables.Conn{},
|
|
chains: make(map[string]map[string]*nftables.Chain),
|
|
rules: make(map[string]*nftables.Rule),
|
|
}
|
|
|
|
nftablesTestingClient := &nftables.Conn{}
|
|
|
|
defer manager.CleanRoutingRules()
|
|
|
|
err := manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
err = manager.InsertRoutingRules(testCase.inputPair)
|
|
require.NoError(t, err, "forwarding pair should be inserted")
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
|
testingExpression := append(sourceExp, destExp...)
|
|
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
|
|
found := 0
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
|
found = 1
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
|
|
if testCase.inputPair.masquerade {
|
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
found := 0
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
|
found = 1
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
}
|
|
|
|
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
|
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
|
testingExpression = append(sourceExp, destExp...)
|
|
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
|
|
|
found = 0
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
|
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
|
|
found = 1
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
|
|
if testCase.inputPair.masquerade {
|
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
|
found := 0
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
|
found = 1
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|
|
|
for _, testCase := range removeRuleTestCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
manager := &nftablesManager{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
conn: &nftables.Conn{},
|
|
chains: make(map[string]map[string]*nftables.Chain),
|
|
rules: make(map[string]*nftables.Rule),
|
|
}
|
|
|
|
nftablesTestingClient := &nftables.Conn{}
|
|
|
|
defer manager.CleanRoutingRules()
|
|
|
|
err := manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
table := manager.tableIPv4
|
|
if testCase.ipVersion == ipv6 {
|
|
table = manager.tableIPv6
|
|
}
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
|
|
|
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: table,
|
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
|
Exprs: forwardExp,
|
|
UserData: []byte(forwardRuleKey),
|
|
})
|
|
|
|
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
|
|
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: table,
|
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
|
Exprs: natExp,
|
|
UserData: []byte(natRuleKey),
|
|
})
|
|
|
|
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
|
|
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
|
|
|
|
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
|
|
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: table,
|
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
|
Exprs: forwardExp,
|
|
UserData: []byte(inForwardRuleKey),
|
|
})
|
|
|
|
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
|
|
|
|
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: table,
|
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
|
Exprs: natExp,
|
|
UserData: []byte(inNatRuleKey),
|
|
})
|
|
|
|
err = nftablesTestingClient.Flush()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
manager.tableIPv4 = nil
|
|
manager.tableIPv6 = nil
|
|
|
|
err = manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
err = manager.RemoveRoutingRules(testCase.inputPair)
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 {
|
|
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
|
|
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
|
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
|
|
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|