diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 679f288e3..929e8a656 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -218,6 +218,14 @@ func (m *Manager) SetLogLevel(log.Level) { // not supported } +func (m *Manager) EnableRouting() error { + return nil +} + +func (m *Manager) DisableRouting() error { + return nil +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index de25ff1f1..d007e20a5 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -101,6 +101,10 @@ type Manager interface { Flush() error SetLogLevel(log.Level) + + EnableRouting() error + + DisableRouting() error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 4fe52bd53..de68f3291 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -323,6 +323,14 @@ func (m *Manager) SetLogLevel(log.Level) { // not supported } +func (m *Manager) EnableRouting() error { + return nil +} + +func (m *Manager) DisableRouting() error { + return nil +} + // Flush rule/chain/set operations from the buffer // // Method also get all rules after flush and refreshes handle values in the rulesets diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 889e4cbb1..5bb225ccd 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -74,6 +74,8 @@ type Manager struct { mutex sync.RWMutex + // indicates whether server routes are disabled + disableServerRoutes bool // indicates whether we forward packets not destined for ourselves routingEnabled bool // indicates whether we leave forwarding and filtering to the native firewall @@ -125,15 +127,27 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. return mgr, nil } +func parseCreateEnv() (bool, bool) { + var disableConntrack, enableLocalForwarding bool + var err error + if val := os.Getenv(EnvDisableConntrack); val != "" { + disableConntrack, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err) + } + } + if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" { + enableLocalForwarding, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) + } + } + + return disableConntrack, enableLocalForwarding +} + func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { - disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) - if err != nil { - log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err) - } - enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding)) - if err != nil { - log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) - } + disableConntrack, enableLocalForwarding := parseCreateEnv() m := &Manager{ decoders: sync.Pool{ @@ -149,15 +163,16 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe return d }, }, - nativeFirewall: nativeFirewall, - outgoingRules: make(map[string]RuleSet), - incomingRules: make(map[string]RuleSet), - wgIface: iface, - localipmanager: newLocalIPManager(), - routingEnabled: false, - stateful: !disableConntrack, - logger: nblog.NewFromLogrus(log.StandardLogger()), - netstack: netstack.IsEnabled(), + nativeFirewall: nativeFirewall, + outgoingRules: make(map[string]RuleSet), + incomingRules: make(map[string]RuleSet), + wgIface: iface, + localipmanager: newLocalIPManager(), + disableServerRoutes: disableServerRoutes, + routingEnabled: false, + stateful: !disableConntrack, + logger: nblog.NewFromLogrus(log.StandardLogger()), + netstack: netstack.IsEnabled(), // default true for non-netstack, for netstack only if explicitly enabled localForwarding: !netstack.IsEnabled() || enableLocalForwarding, } @@ -166,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe return nil, fmt.Errorf("update local IPs: %w", err) } - // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") } else { @@ -175,7 +189,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } - m.determineRouting(iface, disableServerRoutes) + // netstack needs the forwarder for local traffic + if m.netstack && m.localForwarding { + if err := m.initForwarder(); err != nil { + log.Errorf("failed to initialize forwarder: %v", err) + } + } if err := m.blockInvalidRouted(iface); err != nil { log.Errorf("failed to block invalid routed traffic: %v", err) @@ -213,9 +232,21 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { return nil } -func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) { - disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)) - forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)) +func (m *Manager) determineRouting() error { + var disableUspRouting, forceUserspaceRouter bool + var err error + if val := os.Getenv(EnvDisableUserspaceRouting); val != "" { + disableUspRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err) + } + } + if val := os.Getenv(EnvForceUserspaceRouter); val != "" { + forceUserspaceRouter, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err) + } + } switch { case disableUspRouting: @@ -223,7 +254,7 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes m.nativeRouter = false log.Info("userspace routing is disabled") - case disableServerRoutes: + case m.disableServerRoutes: // if server routes are disabled we will let packets pass to the native stack m.routingEnabled = true m.nativeRouter = true @@ -252,32 +283,37 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes log.Info("userspace routing enabled by default") } - // netstack needs the forwarder for local traffic - if m.netstack && m.localForwarding || - m.routingEnabled && !m.nativeRouter { - - m.initForwarder(iface) + if m.routingEnabled && !m.nativeRouter { + return m.initForwarder() } + + return nil } // initForwarder initializes the forwarder, it disables routing on errors -func (m *Manager) initForwarder(iface common.IFaceMapper) { - // Only supported in userspace mode as we need to inject packets back into wireguard directly - intf := iface.GetWGDevice() - if intf == nil { - log.Info("forwarding not supported") - m.routingEnabled = false - return +func (m *Manager) initForwarder() error { + if m.forwarder != nil { + return nil } - forwarder, err := forwarder.New(iface, m.logger, m.netstack) - if err != nil { - log.Errorf("failed to create forwarder: %v", err) + // Only supported in userspace mode as we need to inject packets back into wireguard directly + intf := m.wgIface.GetWGDevice() + if intf == nil { m.routingEnabled = false - return + return errors.New("forwarding not supported") + } + + forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack) + if err != nil { + m.routingEnabled = false + return fmt.Errorf("create forwarder: %w", err) } m.forwarder = forwarder + + log.Debug("forwarder initialized") + + return nil } func (m *Manager) Init(*statemanager.Manager) error { @@ -285,7 +321,7 @@ func (m *Manager) Init(*statemanager.Manager) error { } func (m *Manager) IsServerRouteSupported() bool { - return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil + return true } func (m *Manager) AddNatRule(pair firewall.RouterPair) error { @@ -586,7 +622,6 @@ func (m *Manager) dropFilter(packetData []byte) bool { defer m.decoders.Put(d) if !m.isValidPacket(d, packetData) { - m.logger.Trace("Invalid packet structure") return true } @@ -658,11 +693,9 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat return false } - // Get protocol and ports for route ACL check proto := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - // Check route ACLs if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) { m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v", srcIP, srcPort, dstIP, dstPort, proto) @@ -704,12 +737,12 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - log.Tracef("couldn't decode layer, err: %s", err) + m.logger.Trace("couldn't decode packet, err: %s", err) return false } if len(d.decoded) < 2 { - log.Tracef("not enough levels in network packet") + m.logger.Trace("packet doesn't have network and transport layers") return false } return true @@ -953,3 +986,34 @@ func (m *Manager) SetLogLevel(level log.Level) { m.logger.SetLevel(nblog.Level(level)) } } + +func (m *Manager) EnableRouting() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.determineRouting() +} + +func (m *Manager) DisableRouting() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.forwarder == nil { + return nil + } + + m.routingEnabled = false + m.nativeRouter = false + + // don't stop forwarder if in use by netstack + if m.netstack && m.localForwarding { + return nil + } + + m.forwarder.Stop() + m.forwarder = nil + + log.Debug("forwarder stopped") + + return nil +} diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index d7aebb1aa..9a1456d00 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -303,6 +303,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { } manager, err := Create(ifaceMock, false) + require.NoError(tb, manager.EnableRouting()) require.NoError(tb, err) require.NotNil(tb, manager) require.True(tb, manager.routingEnabled) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9c7f1f6fa..52de0948b 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -286,15 +286,15 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.updateClientNetworks(updateSerial, filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes) } + m.clientRoutes = newClientRoutesIDMap - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return err - } + if m.serverRouter == nil { + return nil } - m.clientRoutes = newClientRoutesIDMap + if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil { + return fmt.Errorf("update routes: %w", err) + } return nil } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index b60cb318e..4690e3f0e 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -71,9 +71,15 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { } if len(m.routes) > 0 { - err := systemops.EnableIPForwarding() - if err != nil { - return err + if err := systemops.EnableIPForwarding(); err != nil { + return fmt.Errorf("enable ip forwarding: %w", err) + } + if err := m.firewall.EnableRouting(); err != nil { + return fmt.Errorf("enable routing: %w", err) + } + } else { + if err := m.firewall.DisableRouting(); err != nil { + return fmt.Errorf("disable routing: %w", err) } }