diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index d3441c69a..e9dfbd7ab 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -16,6 +16,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" @@ -76,6 +77,7 @@ type router struct { legacyManagement bool stateManager *statemanager.Manager + ipFwdState *ipfwdstate.IPForwardingState } func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { @@ -83,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, + ipFwdState: ipfwdstate.NewIPForwardingState(), } r.ipsetCounter = refcounter.New( @@ -217,6 +220,10 @@ func (r *router) deleteIpSet(setName string) error { // AddNatRule inserts an iptables rule pair into the nat chain func (r *router) AddNatRule(pair firewall.RouterPair) error { + if err := r.ipFwdState.RequestForwarding(); err != nil { + return err + } + if r.legacyManagement { log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) if err := r.addLegacyRouteRule(pair); err != nil { @@ -243,6 +250,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { // RemoveNatRule removes an iptables rule pair from forwarding and nat chains func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.ipFwdState.ReleaseForwarding(); err != nil { + log.Errorf("%v", err) + } + if err := r.removeNatRule(pair); err != nil { return fmt.Errorf("remove nat rule: %w", err) } @@ -575,6 +586,10 @@ func (r *router) updateState() { } func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { + if err := r.ipFwdState.RequestForwarding(); err != nil { + return nil, err + } + ruleKey := rule.ID() if _, exists := r.rules[ruleKey+dnatSuffix]; exists { return rule, nil @@ -669,6 +684,10 @@ func (r *router) rollbackRules(rules map[string]ruleInfo) error { } func (r *router) DeleteDNATRule(rule firewall.Rule) error { + if err := r.ipFwdState.ReleaseForwarding(); err != nil { + log.Errorf("%v", err) + } + ruleKey := rule.ID() var merr *multierror.Error diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 3a96ea39b..6f7ebde5a 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -21,6 +21,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -56,16 +57,18 @@ type router struct { ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] wgIface iFaceMapper + ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool } func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { r := &router{ - conn: &nftables.Conn{}, - workTable: workTable, - chains: make(map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), - wgIface: wgIface, + conn: &nftables.Conn{}, + workTable: workTable, + chains: make(map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + wgIface: wgIface, + ipFwdState: ipfwdstate.NewIPForwardingState(), } r.ipsetCounter = refcounter.New( @@ -464,6 +467,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { // AddNatRule appends a nftables rule pair to the nat chain func (r *router) AddNatRule(pair firewall.RouterPair) error { + if err := r.ipFwdState.RequestForwarding(); err != nil { + return err + } + if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } @@ -890,6 +897,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error // RemoveNatRule removes the prerouting mark rule func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.ipFwdState.ReleaseForwarding(); err != nil { + log.Errorf("%v", err) + } + if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } @@ -951,6 +962,10 @@ func (r *router) refreshRulesMap() error { } func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { + if err := r.ipFwdState.RequestForwarding(); err != nil { + return nil, err + } + ruleKey := rule.ID() if _, exists := r.rules[ruleKey+dnatSuffix]; exists { return rule, nil @@ -1174,6 +1189,10 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey } func (r *router) DeleteDNATRule(rule firewall.Rule) error { + if err := r.ipFwdState.ReleaseForwarding(); err != nil { + log.Errorf("%v", err) + } + ruleKey := rule.ID() if err := r.refreshRulesMap(); err != nil { diff --git a/client/internal/routemanager/ipfwdstate/ipfwdstate.go b/client/internal/routemanager/ipfwdstate/ipfwdstate.go new file mode 100644 index 000000000..da81c18f9 --- /dev/null +++ b/client/internal/routemanager/ipfwdstate/ipfwdstate.go @@ -0,0 +1,51 @@ +package ipfwdstate + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// IPForwardingState is a struct that keeps track of the IP forwarding state. +// todo: read initial state of the IP forwarding from the system and reset the state based on it +type IPForwardingState struct { + enabledCounter int +} + +func NewIPForwardingState() *IPForwardingState { + return &IPForwardingState{} +} + +func (f *IPForwardingState) RequestForwarding() error { + if f.enabledCounter != 0 { + f.enabledCounter++ + return nil + } + + if err := systemops.EnableIPForwarding(); err != nil { + return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err) + } + f.enabledCounter = 1 + log.Info("IP forwarding enabled") + + return nil +} + +func (f *IPForwardingState) ReleaseForwarding() error { + if f.enabledCounter == 0 { + return nil + } + + if f.enabledCounter > 1 { + f.enabledCounter-- + return nil + } + + // if failed to disable IP forwarding we anyway decrement the counter + f.enabledCounter = 0 + + // todo call systemops.DisableIPForwarding() + return nil +} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index b60cb318e..6ff80e52d 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -13,7 +13,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/route" ) @@ -70,13 +69,6 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { m.routes[id] = newRoute } - if len(m.routes) > 0 { - err := systemops.EnableIPForwarding() - if err != nil { - return err - } - } - return nil }