[client] Use the prerouting chain to mark for masquerading to support older systems (#2808)

This commit is contained in:
Viktor Liu
2024-11-07 12:37:04 +01:00
committed by GitHub
parent 3e88b7c56e
commit 509e184e10
8 changed files with 472 additions and 287 deletions

View File

@@ -18,22 +18,24 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
ipv4Nat = "netbird-rt-nat"
nbnet "github.com/netbirdio/netbird/util/net"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
)
@@ -323,24 +325,25 @@ func (r *router) Reset() error {
}
func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules()
if err != nil {
return err
if err := r.cleanJumpRules(); err != nil {
return fmt.Errorf("clean jump rules: %w", err)
}
log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
}
@@ -349,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %w", chain, err)
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
@@ -359,6 +369,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
@@ -366,6 +380,32 @@ func (r *router) createContainers() error {
return nil
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
@@ -377,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
}
func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
@@ -398,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
}
func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
}
r.rules[ipv4Nat] = rule
r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil
}
func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat]
if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
}
}
return nil
}
@@ -424,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
@@ -444,13 +518,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("nat rule %s not found", ruleKey)
log.Debugf("marking rule %s not found", ruleKey)
}
return nil
@@ -482,16 +555,6 @@ func (r *router) updateState() {
}
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string