diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index eae9f7e25..869b0b359 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -38,10 +38,12 @@ const ( routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" - jumpManglePre = "jump-mangle-pre" - jumpNatPre = "jump-nat-pre" - jumpNatPost = "jump-nat-post" - matchSet = "--match-set" + jumpManglePre = "jump-mangle-pre" + jumpNatPre = "jump-nat-pre" + jumpNatPost = "jump-nat-post" + markManglePre = "mark-mangle-pre" + markManglePost = "mark-mangle-post" + matchSet = "--match-set" dnatSuffix = "_dnat" snatSuffix = "_snat" @@ -115,6 +117,10 @@ func (r *router) init(stateManager *statemanager.Manager) error { return fmt.Errorf("create containers: %w", err) } + if err := r.setupDataPlaneMark(); err != nil { + log.Errorf("failed to set up data plane mark: %v", err) + } + r.updateState() return nil @@ -348,12 +354,16 @@ func (r *router) Reset() error { if err := r.cleanUpDefaultForwardRules(); err != nil { merr = multierror.Append(merr, err) } - r.rules = make(map[string][]string) if err := r.ipsetCounter.Flush(); err != nil { merr = multierror.Append(merr, err) } + if err := r.cleanupDataPlaneMark(); err != nil { + merr = multierror.Append(merr, err) + } + + r.rules = make(map[string][]string) r.updateState() return nberrors.FormatErrorOrNil(merr) @@ -423,6 +433,57 @@ func (r *router) createContainers() error { return nil } +// setupDataPlaneMark configures the fwmark for the data plane +func (r *router) setupDataPlaneMark() error { + var merr *multierror.Error + preRule := []string{ + "-i", r.wgIface.Name(), + "-m", "conntrack", "--ctstate", "NEW", + "-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn), + } + + if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err)) + 
} else { + r.rules[markManglePre] = preRule + } + + postRule := []string{ + "-o", r.wgIface.Name(), + "-m", "conntrack", "--ctstate", "NEW", + "-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut), + } + + if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err)) + } else { + r.rules[markManglePost] = postRule + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) cleanupDataPlaneMark() error { + var merr *multierror.Error + if preRule, exists := r.rules[markManglePre]; exists { + if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err)) + } else { + delete(r.rules, markManglePre) + } + } + + if postRule, exists := r.rules[markManglePost]; exists { + if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err)) + } else { + delete(r.rules, markManglePost) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *router) addPostroutingRules() error { // First rule for outbound masquerade rule1 := []string{ @@ -464,7 +525,7 @@ func (r *router) insertEstablishedRule(chain string) error { } func (r *router) addJumpRules() error { - // Jump to NAT chain + // Jump to nat chain natRule := []string{"-j", chainRTNAT} if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { return fmt.Errorf("add nat postrouting jump rule: %v", err) diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index c039f3674..dad77dee7 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -46,7 +46,9 @@ func 
TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { // 5. jump rule to PRE nat chain // 6. static outbound masquerade rule // 7. static return masquerade rule - require.Len(t, manager.rules, 7, "should have created rules map") + // 8. mangle prerouting mark rule + // 9. mangle postrouting mark rule + require.Len(t, manager.rules, 9, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 24ffe3386..b6e9a930b 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -25,9 +25,10 @@ const ( chainNameInputRules = "netbird-acl-input-rules" // filter chains contains the rules that jump to the rules chains - chainNameInputFilter = "netbird-acl-input-filter" - chainNameForwardFilter = "netbird-acl-forward-filter" - chainNamePrerouting = "netbird-rt-prerouting" + chainNameInputFilter = "netbird-acl-input-filter" + chainNameForwardFilter = "netbird-acl-forward-filter" + chainNameManglePrerouting = "netbird-mangle-prerouting" + chainNameManglePostrouting = "netbird-mangle-postrouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) @@ -462,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) { // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // netbird peer IP. 
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { - m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{ - Name: chainNamePrerouting, + // Chain is created by route manager + // TODO: move creation to a common place + m.chainPrerouting = &nftables.Chain{ + Name: chainNameManglePrerouting, Table: m.workTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityMangle, - }) + } m.addFwmarkToForward(chainFwFilter) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 6def30bf0..aff86dd90 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -100,6 +100,10 @@ func (r *router) init(workTable *nftables.Table) error { return fmt.Errorf("create containers: %w", err) } + if err := r.setupDataPlaneMark(); err != nil { + log.Errorf("failed to set up data plane mark: %v", err) + } + return nil } @@ -196,15 +200,21 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) - // Chain is created by acl manager - // TODO: move creation to a common place - r.chains[chainNamePrerouting] = &nftables.Chain{ - Name: chainNamePrerouting, + r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameManglePostrouting, + Table: r.workTable, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + }) + + r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameManglePrerouting, Table: r.workTable, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityMangle, Type: nftables.ChainTypeFilter, - } + }) // Add the single NAT rule that matches on mark if err := r.addPostroutingRules(); err != nil { @@ -220,7 +230,83 @@ func (r *router) createContainers() error { } if err := r.conn.Flush(); err != nil { - return fmt.Errorf("nftables: unable to initialize 
table: %v", err) + return fmt.Errorf("initialize tables: %v", err) + } + + return nil +} + +// setupDataPlaneMark configures the fwmark for the data plane +func (r *router) setupDataPlaneMark() error { + if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil { + return errors.New("no mangle chains found") + } + + ctNew := getCtNewExprs() + preExprs := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + } + preExprs = append(preExprs, ctNew...) + preExprs = append(preExprs, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn), + }, + &expr.Ct{ + Key: expr.CtKeyMARK, + Register: 1, + SourceRegister: true, + }, + ) + + preNftRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameManglePrerouting], + Exprs: preExprs, + } + r.conn.AddRule(preNftRule) + + postExprs := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + } + postExprs = append(postExprs, ctNew...) + postExprs = append(postExprs, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut), + }, + &expr.Ct{ + Key: expr.CtKeyMARK, + Register: 1, + SourceRegister: true, + }, + ) + + postNftRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameManglePostrouting], + Exprs: postExprs, + } + r.conn.AddRule(postNftRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) } return nil @@ -516,26 +602,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { op = expr.CmpOpNeq } - exprs := []expr.Any{ - // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. 
- // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 1, - }, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, - + // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading. + // Masquerading will take care of the conntrack state, which means we won't need to mark established connections. + exprs := getCtNewExprs() + exprs = append(exprs, // interface matching &expr.Meta{ Key: expr.MetaKeyIIFNAME, @@ -546,7 +616,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: ifname(r.wgIface.Name()), }, - } + ) exprs = append(exprs, sourceExp...) exprs = append(exprs, destExp...) @@ -578,7 +648,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, - Chain: r.chains[chainNamePrerouting], + Chain: r.chains[chainNameManglePrerouting], Exprs: exprs, UserData: []byte(ruleKey), }) @@ -1324,3 +1394,24 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any { return exprs } + +func getCtNewExprs() []expr.Any { + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + } +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 498fdf882..28baef4dd 100644 --- a/client/firewall/nftables/router_linux_test.go +++ 
b/client/firewall/nftables/router_linux_test.go @@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) found := 0 for _, chain := range rtr.chains { - if chain.Name == chainNamePrerouting { + if chain.Name == chainNameManglePrerouting { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) for _, rule := range rules { @@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { // Verify the rule was added natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) found := false - rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) + rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting]) require.NoError(t, err, "should list rules") for _, rule := range rules { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { @@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { // Verify the rule was removed found = false - rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) + rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting]) require.NoError(t, err, "should list rules after removal") for _, rule := range rules { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index a3de58c24..e536f2650 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string { func getFwmark() int { if nbnet.AdvancedRouting() { - return nbnet.NetbirdFwmark + return nbnet.ControlPlaneMark } return 0 } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index c0614406d..f8440b913 
100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,6 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + nbnet "github.com/netbirdio/netbird/util/net" ) const defaultChannelSize = 100 @@ -176,7 +177,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { srcIP := flow.TupleOrig.IP.SourceAddress dstIP := flow.TupleOrig.IP.DestinationAddress - if !c.relevantFlow(srcIP, dstIP) { + if !c.relevantFlow(flow.Mark, srcIP, dstIP) { return } @@ -193,7 +194,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { } flowID := c.getFlowID(flow.ID) - direction := c.inferDirection(srcIP, dstIP) + direction := c.inferDirection(flow.Mark, srcIP, dstIP) eventType := nftypes.TypeStart eventStr := "New" @@ -224,15 +225,14 @@ func (c *ConnTrack) handleEvent(event nfct.Event) { } // relevantFlow checks if the flow is related to the specified interface -func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool { - // TODO: filter traffic by interface - - wgnet := c.iface.Address().Network - if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) { - return false +func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool { + if nbnet.IsDataPlaneMark(mark) { + return true } - return true + // fallback if mark rules are not in place + wgnet := c.iface.Address().Network + return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice()) } // mapRxPackets maps packet counts to RX based on flow direction @@ -282,7 +282,15 @@ func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID { return uuid.NewSHA1(c.instanceID, buf[:]) } -func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction { +func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes.Direction { + switch mark { + case nbnet.DataPlaneMarkIn: + return nftypes.Ingress + case nbnet.DataPlaneMarkOut: + return 
nftypes.Egress + } + + // fallback if marks are not set wgaddr := c.iface.Address().IP wgnetwork := c.iface.Address().Network src, dst := srcIP.AsSlice(), dstIP.AsSlice() @@ -298,8 +306,6 @@ func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction { case wgnetwork.Contains(dst): // resource network -> netbird network return nftypes.Egress - - // TODO: handle site2site traffic } return nftypes.DirectionUnknown diff --git a/client/internal/netflow/types/types.go b/client/internal/netflow/types/types.go index ea752131b..f76146ba3 100644 --- a/client/internal/netflow/types/types.go +++ b/client/internal/netflow/types/types.go @@ -10,6 +10,8 @@ import ( "github.com/netbirdio/netbird/client/iface/wgaddr" ) +const ZoneID = 0x1BD0 + type Protocol uint8 const ( diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index d724cb1a7..cf3c2f0aa 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -57,8 +57,8 @@ func getSetupRules() []ruleParams { return []ruleParams{ {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"}, {100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"}, - {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, - {110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, + {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"}, + {110, nbnet.ControlPlaneMark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"}, } } diff --git a/util/net/env_linux.go b/util/net/env_linux.go index 124bf64de..3159f6462 100644 --- a/util/net/env_linux.go +++ b/util/net/env_linux.go @@ -88,9 +88,21 @@ func CheckFwmarkSupport() bool { log.Warnf("failed to dial with fwmark: %v", err) 
return false } - if err := conn.Close(); err != nil { - log.Warnf("failed to close connection: %v", err) + defer func() { + if err := conn.Close(); err != nil { + log.Warnf("failed to close connection: %v", err) + } + }() + + if err := conn.SetWriteDeadline(time.Now().Add(time.Millisecond * 100)); err != nil { + log.Warnf("failed to set write deadline: %v", err) + return false + } + + if _, err := conn.Write([]byte("")); err != nil { + log.Warnf("failed to write to fwmark connection: %v", err) + return false } return true diff --git a/util/net/net.go b/util/net/net.go index 7b43b952f..b573f9aeb 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -8,14 +8,40 @@ import ( ) const ( - // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + // ControlPlaneMark is the fwmark value used to mark packets that should not be routed through the NetBird interface to + // avoid routing loops. + // This includes all control plane traffic (mgmt, signal, flows), relay, ICE/stun/turn and everything that is emitted by the wireguard socket. + // It doesn't collide with the other marks, as the others are used for data plane traffic only. + ControlPlaneMark = 0x1BD00 - PreroutingFwmarkRedirected = 0x1BD01 - PreroutingFwmarkMasquerade = 0x1BD11 - PreroutingFwmarkMasqueradeReturn = 0x1BD12 + // Data plane marks (0x1BD10 - 0x1BDFF) + + // DataPlaneMarkLower is the lowest value for the data plane range + DataPlaneMarkLower = 0x1BD10 + // DataPlaneMarkUpper is the highest value for the data plane range + DataPlaneMarkUpper = 0x1BDFF + + // DataPlaneMarkIn is the mark for inbound data plane traffic. + DataPlaneMarkIn = 0x1BD10 + + // DataPlaneMarkOut is the mark for outbound data plane traffic. + DataPlaneMarkOut = 0x1BD11 + + // PreroutingFwmarkRedirected is applied to packets that were redirected (input -> forward, e.g. by Docker or Podman) for special handling. 
+ PreroutingFwmarkRedirected = 0x1BD20 + + // PreroutingFwmarkMasquerade is applied to packets that arrive from the NetBird interface and should be masqueraded. + PreroutingFwmarkMasquerade = 0x1BD21 + + // PreroutingFwmarkMasqueradeReturn is applied to packets that will leave through the NetBird interface and should be masqueraded. + PreroutingFwmarkMasqueradeReturn = 0x1BD22 ) +// IsDataPlaneMark determines if a fwmark is in the data plane range (0x1BD10-0x1BDFF) +func IsDataPlaneMark(fwmark uint32) bool { + return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper +} + // ConnectionID provides a globally unique identifier for network connections. // It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. type ConnectionID string diff --git a/util/net/net_linux.go b/util/net/net_linux.go index eae483a26..9e7d13702 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -51,5 +51,5 @@ func setRawSocketMark(conn syscall.RawConn) error { } func setSocketOptInt(fd int) error { - return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, ControlPlaneMark) }