mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-01 20:43:43 +01:00
1012172f04
Handle route updates from management; manage routing firewall rules; manage peer RIB table; add get peer and get notification channel from the status recorder; update interface peers' allowed IPs.
271 lines
9.5 KiB
Go
271 lines
9.5 KiB
Go
package routemanager
|
|
|
|
import (
|
|
"context"
|
|
"github.com/google/nftables"
|
|
"github.com/google/nftables/expr"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"testing"
|
|
)
|
|
|
|
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
manager := &nftablesManager{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
conn: &nftables.Conn{},
|
|
chains: make(map[string]map[string]*nftables.Chain),
|
|
rules: make(map[string]*nftables.Rule),
|
|
}
|
|
|
|
nftablesTestingClient := &nftables.Conn{}
|
|
|
|
defer manager.CleanRoutingRules()
|
|
|
|
err := manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
|
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
|
|
|
|
pair := routerPair{
|
|
ID: "abc",
|
|
source: "100.100.100.1/32",
|
|
destination: "100.100.100.0/24",
|
|
masquerade: true,
|
|
}
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
|
|
|
|
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
|
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv4,
|
|
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
|
|
Exprs: forward4Exp,
|
|
UserData: []byte(forward4RuleKey),
|
|
})
|
|
|
|
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
nat4RuleKey := genKey(natFormat, pair.ID)
|
|
|
|
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv4,
|
|
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
|
|
Exprs: nat4Exp,
|
|
UserData: []byte(nat4RuleKey),
|
|
})
|
|
|
|
err = nftablesTestingClient.Flush()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
pair = routerPair{
|
|
ID: "xyz",
|
|
source: "fc00::1/128",
|
|
destination: "fc11::/64",
|
|
masquerade: true,
|
|
}
|
|
|
|
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
|
|
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
|
|
|
|
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
|
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv6,
|
|
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
|
|
Exprs: forward6Exp,
|
|
UserData: []byte(forward6RuleKey),
|
|
})
|
|
|
|
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
nat6RuleKey := genKey(natFormat, pair.ID)
|
|
|
|
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: manager.tableIPv6,
|
|
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
|
|
Exprs: nat6Exp,
|
|
UserData: []byte(nat6RuleKey),
|
|
})
|
|
|
|
err = nftablesTestingClient.Flush()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
manager.tableIPv4 = nil
|
|
manager.tableIPv6 = nil
|
|
|
|
err = manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
|
|
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
|
|
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
|
|
|
|
foundRule, found := manager.rules[forward4RuleKey]
|
|
require.True(t, found, "forwarding rule should exist in the map")
|
|
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
|
|
|
|
foundRule, found = manager.rules[nat4RuleKey]
|
|
require.True(t, found, "nat rule should exist in the map")
|
|
// match len of output as nftables client doesn't return expressions with masquerade expression
|
|
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
|
|
|
|
foundRule, found = manager.rules[forward6RuleKey]
|
|
require.True(t, found, "forwarding rule should exist in the map")
|
|
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
|
|
|
|
foundRule, found = manager.rules[nat6RuleKey]
|
|
require.True(t, found, "nat rule should exist in the map")
|
|
// match len of output as nftables client doesn't return expressions with masquerade expression
|
|
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
|
|
}
|
|
|
|
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
|
|
|
for _, testCase := range insertRuleTestCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
manager := &nftablesManager{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
conn: &nftables.Conn{},
|
|
chains: make(map[string]map[string]*nftables.Chain),
|
|
rules: make(map[string]*nftables.Rule),
|
|
}
|
|
|
|
nftablesTestingClient := &nftables.Conn{}
|
|
|
|
defer manager.CleanRoutingRules()
|
|
|
|
err := manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
err = manager.InsertRoutingRules(testCase.inputPair)
|
|
require.NoError(t, err, "forwarding pair should be inserted")
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
|
testingExpression := append(sourceExp, destExp...)
|
|
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
|
|
found := 0
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
|
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
|
|
found = 1
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
|
|
if testCase.inputPair.masquerade {
|
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
found := 0
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
|
found = 1
|
|
}
|
|
}
|
|
}
|
|
}
|
|
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
|
|
|
for _, testCase := range removeRuleTestCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.TODO())
|
|
|
|
manager := &nftablesManager{
|
|
ctx: ctx,
|
|
stop: cancel,
|
|
conn: &nftables.Conn{},
|
|
chains: make(map[string]map[string]*nftables.Chain),
|
|
rules: make(map[string]*nftables.Rule),
|
|
}
|
|
|
|
nftablesTestingClient := &nftables.Conn{}
|
|
|
|
defer manager.CleanRoutingRules()
|
|
|
|
err := manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
table := manager.tableIPv4
|
|
if testCase.ipVersion == ipv6 {
|
|
table = manager.tableIPv6
|
|
}
|
|
|
|
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
|
|
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
|
|
|
|
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
|
|
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
|
|
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: table,
|
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
|
|
Exprs: forwardExp,
|
|
UserData: []byte(forwardRuleKey),
|
|
})
|
|
|
|
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
|
|
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
|
|
|
|
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
|
Table: table,
|
|
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
|
|
Exprs: natExp,
|
|
UserData: []byte(natRuleKey),
|
|
})
|
|
|
|
err = nftablesTestingClient.Flush()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
manager.tableIPv4 = nil
|
|
manager.tableIPv6 = nil
|
|
|
|
err = manager.RestoreOrCreateContainers()
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
err = manager.RemoveRoutingRules(testCase.inputPair)
|
|
require.NoError(t, err, "shouldn't return error")
|
|
|
|
for _, registeredChains := range manager.chains {
|
|
for _, chain := range registeredChains {
|
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
|
for _, rule := range rules {
|
|
if len(rule.UserData) > 0 {
|
|
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
|
|
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|