diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 5cd69245b..1c0527ebc 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialOptionalEntries() { m.optionalEntries["FORWARD"] = []entry{ { - spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules}, position: 2, }, } m.optionalEntries["PREROUTING"] = []entry{ { - spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)}, position: 1, }, } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 9b75640b4..16e2e97f4 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -18,19 +18,19 @@ import ( "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" -) - -const ( - ipv4Nat = "netbird-rt-nat" + nbnet "github.com/netbirdio/netbird/util/net" ) // constants needed to manage and create iptable rules const ( tableFilter = "filter" tableNat = "nat" + tableMangle = "mangle" chainPOSTROUTING = "POSTROUTING" + chainPREROUTING = "PREROUTING" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWD = "NETBIRD-RT-FWD" + chainRTPRE = "NETBIRD-RT-PRE" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" @@ -323,24 +323,25 @@ func (r *router) Reset() error { } func (r *router) cleanUpDefaultForwardRules() error { - err := r.cleanJumpRules() - if err != nil { - return err + if err := r.cleanJumpRules(); err != nil { + return fmt.Errorf("clean jump rules: %w", err) } log.Debug("flushing routing related tables") - for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := r.getTableForChain(chain) - - ok, err := r.iptablesClient.ChainExists(table, chain) + for _, chainInfo := range []struct { + chain string + table string + }{ + {chainRTFWD, tableFilter}, + {chainRTNAT, tableNat}, + {chainRTPRE, tableMangle}, + } { + ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { - log.Errorf("failed check chain %s, error: %v", chain, err) - return err + return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } else if ok { - err = r.iptablesClient.ClearAndDeleteChain(table, chain) - if err != nil { - log.Errorf("failed cleaning chain %s, error: %v", chain, err) - return err + if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil { + return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } } } @@ -349,9 +350,16 @@ func (r *router) cleanUpDefaultForwardRules() error { } func (r *router) createContainers() error { - for _, chain := range []string{chainRTFWD, chainRTNAT} { - if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %w", chain, err) + for _, chainInfo := range []struct { + chain string + table string + }{ + {chainRTFWD, tableFilter}, + {chainRTPRE, tableMangle}, + {chainRTNAT, tableNat}, + } { + if err := r.createAndSetupChain(chainInfo.chain); err != nil { + return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) } } @@ -359,6 +367,10 @@ func (r *router) createContainers() error { return fmt.Errorf("insert established rule: %w", err) } + if err := r.addPostroutingRules(); err != nil { + return fmt.Errorf("add static nat rules: %w", err) + } + if err := r.addJumpRules(); err != nil { return fmt.Errorf("add jump rules: %w", err) } @@ -366,6 +378,32 @@ func (r *router) createContainers() error { return nil } +func (r *router) addPostroutingRules() error { + // First rule for outbound masquerade + rule1 := []string{ + "-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade), + "!", "-o", "lo", + "-j", routingFinalNatJump, + } + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil { + return fmt.Errorf("add outbound masquerade rule: %v", err) + } + r.rules["static-nat-outbound"] = rule1 + + // Second rule for return traffic masquerade + rule2 := []string{ + "-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn), + "-o", r.wgIface.Name(), + "-j", routingFinalNatJump, + } + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil { + return fmt.Errorf("add return masquerade rule: %v", err) + } + r.rules["static-nat-return"] = rule2 + + return nil +} + func (r *router) createAndSetupChain(chain string) error { table := r.getTableForChain(chain) @@ -377,10 +415,14 @@ func (r *router) createAndSetupChain(chain string) error { } func (r *router) getTableForChain(chain string) string { - if chain == chainRTNAT { + switch chain { + case chainRTNAT: return tableNat + case chainRTPRE: + return tableMangle + default: + return tableFilter } - return tableFilter } func (r *router) insertEstablishedRule(chain string) error { @@ -398,25 +440,39 @@ func (r *router) insertEstablishedRule(chain string) error { } func (r *router) addJumpRules() error { - rule := []string{"-j", chainRTNAT} - err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) - if err != nil { - return err + // Jump to NAT chain + natRule := []string{"-j", chainRTNAT} + if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { + return fmt.Errorf("add nat jump rule: %v", err) } - r.rules[ipv4Nat] = rule + r.rules["jump-nat"] = natRule + + // Jump to prerouting chain + preRule := []string{"-j", chainRTPRE} + if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil { + return fmt.Errorf("add prerouting jump rule: %v", err) + } + r.rules["jump-pre"] = preRule return nil } func (r *router) cleanJumpRules() error { - rule, found := r.rules[ipv4Nat] - if found { - err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) - if err != nil { - return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) + for _, ruleKey := range []string{"jump-nat", "jump-pre"} { + if rule, exists := r.rules[ruleKey]; exists { + table := tableNat + chain := chainPOSTROUTING + if ruleKey == "jump-pre" { + table = tableMangle + chain = chainPREROUTING + } + + if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil { + return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err) + } + delete(r.rules, ruleKey) } } - return nil } @@ -424,19 +480,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.NatFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { - return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) + if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil { + return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err) } delete(r.rules, ruleKey) } - rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) - if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { - return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) + markValue := nbnet.PreroutingFwmarkMasquerade + if pair.Inverse { + markValue = nbnet.PreroutingFwmarkMasqueradeReturn + } + + rule := []string{"-i", r.wgIface.Name()} + if pair.Inverse { + rule = []string{"!", "-i", r.wgIface.Name()} + } + + rule = append(rule, + "-m", "conntrack", + "--ctstate", "NEW", + "-s", pair.Source.String(), + "-d", pair.Destination.String(), + "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), + ) + + if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { + return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) } r.rules[ruleKey] = rule - return nil } @@ -444,13 +516,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { ruleKey := firewall.GenKey(firewall.NatFormat, pair) if rule, exists := r.rules[ruleKey]; exists { - if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { - return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) + if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil { + return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) } - delete(r.rules, ruleKey) } else { - log.Debugf("nat rule %s not found", ruleKey) + log.Debugf("marking rule %s not found", ruleKey) } return nil @@ -482,16 +553,6 @@ func (r *router) updateState() { } } -func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { - intdir := "-i" - lointdir := "-o" - if inverse { - intdir = "-o" - lointdir = "-i" - } - return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump} -} - func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { var rule []string diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 2d821a9db..861bf8601 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -3,17 +3,18 @@ package iptables import ( + "fmt" "net/netip" "os/exec" "testing" "github.com/coreos/go-iptables/iptables" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + nbnet "github.com/netbirdio/netbird/util/net" ) func isIptablesSupported() bool { @@ -34,14 +35,24 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, manager.init(nil)) defer func() { - _ = manager.Reset() + assert.NoError(t, manager.Reset(), "shouldn't return error") }() - require.Len(t, manager.rules, 2, "should have created rules map") + // Now 5 rules: + // 1. established rule in forward chain + // 2. jump rule to NAT chain + // 3. jump rule to PRE chain + // 4. static outbound masquerade rule + // 5. static return masquerade rule + require.Len(t, manager.rules, 5, "should have created rules map") - exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) - require.True(t, exists, "postrouting rule should exist") + require.True(t, exists, "postrouting jump rule should exist") + + exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING) + require.True(t, exists, "prerouting jump rule should exist") pair := firewall.RouterPair{ ID: "abc", @@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { Destination: netip.MustParsePrefix("100.100.100.0/24"), Masquerade: true, } - forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} - err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) - require.NoError(t, err, "inserting rule should not return error") - - nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) - - err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) - require.NoError(t, err, "inserting rule should not return error") + err = manager.AddNatRule(pair) + require.NoError(t, err, "adding NAT rule should not return error") err = manager.Reset() require.NoError(t, err, "shouldn't return error") } func TestIptablesManager_AddNatRule(t *testing.T) { - if !isIptablesSupported() { t.SkipNow() } @@ -79,52 +83,66 @@ func TestIptablesManager_AddNatRule(t *testing.T) { require.NoError(t, manager.init(nil)) defer func() { - err := manager.Reset() - if err != nil { - log.Errorf("failed to reset iptables manager: %s", err) - } + assert.NoError(t, manager.Reset(), "shouldn't return error") }() err = manager.AddNatRule(testCase.InputPair) - require.NoError(t, err, "forwarding pair should be inserted") + require.NoError(t, err, "marking rule should be inserted") natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - - exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) - if testCase.InputPair.Masquerade { - require.True(t, exists, "nat rule should be created") - foundNatRule, foundNat := manager.rules[natRuleKey] - require.True(t, foundNat, "nat rule should exist in the map") - require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match") - } else { - require.False(t, exists, "nat rule should not be created") - _, foundNat := manager.rules[natRuleKey] - require.False(t, foundNat, "nat rule should not exist in the map") + markingRule := []string{ + "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", testCase.InputPair.Source.String(), + "-d", testCase.InputPair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade), } - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) + exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) if testCase.InputPair.Masquerade { - require.True(t, exists, "income nat rule should be created") - foundNatRule, foundNat := manager.rules[inNatRuleKey] - require.True(t, foundNat, "income nat rule should exist in the map") - require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") + require.True(t, exists, "marking rule should be created") + foundRule, found := manager.rules[natRuleKey] + require.True(t, found, "marking rule should exist in the map") + require.Equal(t, markingRule, foundRule, "stored marking rule should match") } else { - require.False(t, exists, "nat rule should not be created") - _, foundNat := manager.rules[inNatRuleKey] - require.False(t, foundNat, "income nat rule should not exist in the map") + require.False(t, exists, "marking rule should not be created") + _, found := manager.rules[natRuleKey] + require.False(t, found, "marking rule should not exist in the map") + } + + // Check inverse rule + inversePair := firewall.GetInversePair(testCase.InputPair) + inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair) + inverseMarkingRule := []string{ + "!", "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", inversePair.Source.String(), + "-d", inversePair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn), + } + + exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) + if testCase.InputPair.Masquerade { + require.True(t, exists, "inverse marking rule should be created") + foundRule, found := manager.rules[inverseRuleKey] + require.True(t, found, "inverse marking rule should exist in the map") + require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match") + } else { + require.False(t, exists, "inverse marking rule should not be created") + _, found := manager.rules[inverseRuleKey] + require.False(t, found, "inverse marking rule should not exist in the map") } }) } } func TestIptablesManager_RemoveNatRule(t *testing.T) { - if !isIptablesSupported() { t.SkipNow() } @@ -137,42 +155,52 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { - _ = manager.Reset() + assert.NoError(t, manager.Reset(), "shouldn't return error") }() - require.NoError(t, err, "shouldn't return error") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - - err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) - require.NoError(t, err, "inserting rule should not return error") - - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) - - err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) - require.NoError(t, err, "inserting rule should not return error") - - err = manager.Reset() - require.NoError(t, err, "shouldn't return error") + err = manager.AddNatRule(testCase.InputPair) + require.NoError(t, err, "should add NAT rule without error") err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) - require.False(t, exists, "nat rule should not exist") + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + markingRule := []string{ + "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", testCase.InputPair.Source.String(), + "-d", testCase.InputPair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade), + } + + exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) + require.False(t, exists, "marking rule should not exist") _, found := manager.rules[natRuleKey] - require.False(t, found, "nat rule should exist in the manager map") + require.False(t, found, "marking rule should not exist in the manager map") - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) - require.False(t, exists, "income nat rule should not exist") + // Check inverse rule removal + inversePair := firewall.GetInversePair(testCase.InputPair) + inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair) + inverseMarkingRule := []string{ + "!", "-i", ifaceMock.Name(), + "-m", "conntrack", + "--ctstate", "NEW", + "-s", inversePair.Source.String(), + "-d", inversePair.Destination.String(), + "-j", "MARK", "--set-mark", + fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn), + } - _, found = manager.rules[inNatRuleKey] - require.False(t, found, "income nat rule should exist in the manager map") + exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE) + require.False(t, exists, "inverse marking rule should not exist") + + _, found = manager.rules[inverseRuleKey] + require.False(t, found, "inverse marking rule should not exist in the map") }) } } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 2a40cd9f6..9391b47ec 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -17,6 +17,7 @@ import ( const ( ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormat = "netbird-fwd-%s-%t" + PreroutingFormat = "netbird-prerouting-%s-%t" NatFormat = "netbird-nat-%s-%t" ) diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index ca7b2e59f..abe890fb9 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { }, &expr.Immediate{ Register: 1, - Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), }, &expr.Meta{ Key: expr.MetaKeyMARK, @@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, - Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), }, &expr.Verdict{ Kind: expr.VerdictJump, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0e7ea71b7..34bc9a9bc 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -21,6 +21,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -124,7 +125,6 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) prio := *nftables.ChainPriorityNATSource - 1 - r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, Table: r.workTable, @@ -133,6 +133,21 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) + // Chain is created by acl manager + // TODO: move creation to a common place + r.chains[chainNamePrerouting] = &nftables.Chain{ + Name: chainNamePrerouting, + Table: r.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + } + + // Add the single NAT rule that matches on mark + if err := r.addPostroutingRules(); err != nil { + return fmt.Errorf("add single nat rule: %v", err) + } + if err := r.acceptForwardRules(); err != nil { log.Errorf("failed to add accept rules for the forward chain: %s", err) } @@ -422,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { sourceExp := generateCIDRMatcherExpressions(true, pair.Source) destExp := generateCIDRMatcherExpressions(false, pair.Destination) - dir := expr.MetaKeyIIFNAME - notDir := expr.MetaKeyOIFNAME + op := expr.CmpOpEq if pair.Inverse { - dir = expr.MetaKeyOIFNAME - notDir = expr.MetaKeyIIFNAME + op = expr.CmpOpNeq } - lo := ifname("lo") - intf := ifname(r.wgIface.Name()) - exprs := []expr.Any{ - &expr.Meta{ - Key: dir, + // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. + // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. + &expr.Ct{ + Key: expr.CtKeySTATE, Register: 1, }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, - - // We need to exclude the loopback interface as this changes the ebpf proxy port - &expr.Meta{ - Key: notDir, - Register: 1, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW), + Xor: binaryutil.NativeEndian.PutUint32(0), }, &expr.Cmp{ Op: expr.CmpOpNeq, Register: 1, - Data: lo, + Data: []byte{0, 0, 0, 0}, + }, + + // interface matching + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: op, + Register: 1, + Data: ifname(r.wgIface.Name()), }, } exprs = append(exprs, sourceExp...) exprs = append(exprs, destExp...) + + var markValue uint32 = nbnet.PreroutingFwmarkMasquerade + if pair.Inverse { + markValue = nbnet.PreroutingFwmarkMasqueradeReturn + } + exprs = append(exprs, - &expr.Counter{}, &expr.Masq{}, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(markValue), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + SourceRegister: true, + Register: 1, + }, ) - ruleKey := firewall.GenKey(firewall.NatFormat, pair) + ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if _, exists := r.rules[ruleKey]; exists { if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove routing rule: %w", err) + return fmt.Errorf("remove prerouting rule: %w", err) } } r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, - Chain: r.chains[chainNameRoutingNat], + Chain: r.chains[chainNamePrerouting], Exprs: exprs, UserData: []byte(ruleKey), }) + + return nil +} + +// addPostroutingRules adds the masquerade rules +func (r *router) addPostroutingRules() error { + // First masquerade rule for traffic coming in from WireGuard interface + exprs := []expr.Any{ + // Match on the first fwmark + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade), + }, + + // We need to exclude the loopback interface as this changes the ebpf proxy port + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, + &expr.Counter{}, + &expr.Masq{}, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs, + }) + + // Second masquerade rule for traffic going out through WireGuard interface + exprs2 := []expr.Any{ + // Match on the second fwmark + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn), + }, + + // Match WireGuard interface + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Counter{}, + &expr.Masq{}, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs2, + }) + return nil } @@ -723,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error return nberrors.FormatErrorOrNil(merr) } -// RemoveNatRule removes a nftables rule pair from nat chains +// RemoveNatRule removes the prerouting mark rule func (r *router) RemoveNatRule(pair firewall.RouterPair) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } if err := r.removeNatRule(pair); err != nil { - return fmt.Errorf("remove nat rule: %w", err) + return fmt.Errorf("remove prerouting rule: %w", err) } if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { - return fmt.Errorf("remove inverse nat rule: %w", err) + return fmt.Errorf("remove inverse prerouting rule: %w", err) } if err := r.removeLegacyRouteRule(pair); err != nil { @@ -749,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return nil } -// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map func (r *router) removeNatRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.NatFormat, pair) + ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { err := r.conn.DelRule(rule) if err != nil { - return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) + return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) } - log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) + log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) delete(r.rules, ruleKey) } else { - log.Debugf("nftables: nat rule %s not found", ruleKey) + log.Debugf("nftables: prerouting rule %s not found", ruleKey) } return nil diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 19ed48991..a5b9106a5 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -10,6 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,100 +33,86 @@ func TestNftablesManager_AddNatRule(t *testing.T) { t.Skip("nftables not supported on this OS") } - table, err := createWorkTable() - require.NoError(t, err, "Failed to create work table") - - defer deleteWorkTable() - for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(table, ifaceMock) - require.NoError(t, err, "failed to create router") - require.NoError(t, manager.init(table)) + // need fw manager to init both acl mgr and router for all chains to be present + manager, err := Create(ifaceMock) + t.Cleanup(func() { + require.NoError(t, manager.Reset(nil)) + }) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) nftablesTestingClient := &nftables.Conn{} - defer func(manager *router) { - require.NoError(t, manager.Reset(), "failed to reset rules") - }(manager) - - require.NoError(t, err, "shouldn't return error") - - err = manager.AddNatRule(testCase.InputPair) + rtr := manager.router + err = rtr.AddNatRule(testCase.InputPair) require.NoError(t, err, "pair should be inserted") - defer func(manager *router, pair firewall.RouterPair) { - require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") - }(manager, testCase.InputPair) + t.Cleanup(func() { + require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule") + }) if testCase.InputPair.Masquerade { - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - testingExpression = append(testingExpression, - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + // Build expected expressions for connection tracking + conntrackExprs := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + } + + // Build interface matching expression + ifaceExprs := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: ifname(ifaceMock.Name()), }, - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: ifname("lo"), - }, - ) - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") - found = 1 - } - } } - require.Equal(t, 1, found, "should find at least 1 rule to test") - } - if testCase.InputPair.Masquerade { + // Build CIDR matching expressions sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - testingExpression = append(testingExpression, - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(ifaceMock.Name()), - }, - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: ifname("lo"), - }, - ) - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + // Combine all expressions in the correct order + testingExpression := append(conntrackExprs, ifaceExprs...) + testingExpression = append(testingExpression, sourceExp...) + testingExpression = append(testingExpression, destExp...) + + natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match") - found = 1 + for _, chain := range rtr.chains { + if chain.Name == chainNamePrerouting { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + // Compare expressions up to the mark setting expressions + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match") + found = 1 + } } } } - require.Equal(t, 1, found, "should find at least 1 rule to test") + require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain") } - }) } } @@ -135,68 +122,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { t.Skip("nftables not supported on this OS") } - table, err := createWorkTable() - require.NoError(t, err, "Failed to create work table") - - defer deleteWorkTable() - for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(table, ifaceMock) - require.NoError(t, err, "failed to create router") - require.NoError(t, manager.init(table)) - - nftablesTestingClient := &nftables.Conn{} - - defer func(manager *router) { - require.NoError(t, manager.Reset(), "failed to reset rules") - }(manager) - - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) - - insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRoutingNat], - Exprs: natExp, - UserData: []byte(natRuleKey), + manager, err := Create(ifaceMock) + t.Cleanup(func() { + require.NoError(t, manager.Reset(nil)) }) + require.NoError(t, err) + require.NoError(t, manager.Init(nil)) - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) + rtr := manager.router - natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + // First add the NAT rule using the router's method + err = rtr.AddNatRule(testCase.InputPair) + require.NoError(t, err, "should add NAT rule") - insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRoutingNat], - Exprs: natExp, - UserData: []byte(inNatRuleKey), - }) - - err = nftablesTestingClient.Flush() - require.NoError(t, err, "shouldn't return error") - - err = manager.Reset() - require.NoError(t, err, "shouldn't return error") - - err = manager.RemoveNatRule(testCase.InputPair) - require.NoError(t, err, "shouldn't return error") - - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 { - require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") - } + // Verify the rule was added + natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) + found := false + rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) + require.NoError(t, err, "should list rules") + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + found = true + break } } + require.True(t, found, "NAT rule should exist before removal") + + // Now remove the rule + err = rtr.RemoveNatRule(testCase.InputPair) + require.NoError(t, err, "shouldn't return error when removing rule") + + // Verify the rule was removed + found = false + rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) + require.NoError(t, err, "should list rules after removal") + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + found = true + break + } + } + require.False(t, found, "NAT rule should not exist after removal") + + // Verify the static postrouting rules still exist + rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat]) + require.NoError(t, err, "should list postrouting rules") + foundCounter := false + for _, rule := range rules { + for _, e := range rule.Exprs { + if _, ok := e.(*expr.Counter); ok { + foundCounter = true + break + } + } + if foundCounter { + break + } + } + require.True(t, foundCounter, "static postrouting rule should remain") }) } } diff --git a/util/net/net.go b/util/net/net.go index 035d7552b..5448eb85a 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,8 +11,11 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 - PreroutingFwmark = 0x1BD01 + NetbirdFwmark = 0x1BD00 + + PreroutingFwmarkRedirected = 0x1BD01 + PreroutingFwmarkMasquerade = 0x1BD11 + PreroutingFwmarkMasqueradeReturn = 0x1BD12 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" )