mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 01:38:41 +02:00
[client] Enable userspace forwarder conditionally (#3309)
* Enable userspace forwarder conditionally * Move disable/enable logic
This commit is contained in:
parent
18f84f0df5
commit
b41de7fcd1
@ -218,6 +218,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
// not supported
|
// not supported
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
@ -101,6 +101,10 @@ type Manager interface {
|
|||||||
Flush() error
|
Flush() error
|
||||||
|
|
||||||
SetLogLevel(log.Level)
|
SetLogLevel(log.Level)
|
||||||
|
|
||||||
|
EnableRouting() error
|
||||||
|
|
||||||
|
DisableRouting() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
@ -323,6 +323,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
// not supported
|
// not supported
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
//
|
//
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
@ -74,6 +74,8 @@ type Manager struct {
|
|||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
// indicates whether server routes are disabled
|
||||||
|
disableServerRoutes bool
|
||||||
// indicates whether we forward packets not destined for ourselves
|
// indicates whether we forward packets not destined for ourselves
|
||||||
routingEnabled bool
|
routingEnabled bool
|
||||||
// indicates whether we leave forwarding and filtering to the native firewall
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
@ -125,15 +127,27 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
|
|||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseCreateEnv() (bool, bool) {
|
||||||
|
var disableConntrack, enableLocalForwarding bool
|
||||||
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||||
|
disableConntrack, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvEnableNetstackLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return disableConntrack, enableLocalForwarding
|
||||||
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
||||||
disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err)
|
|
||||||
}
|
|
||||||
enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding))
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
@ -149,15 +163,16 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
nativeFirewall: nativeFirewall,
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
routingEnabled: false,
|
disableServerRoutes: disableServerRoutes,
|
||||||
stateful: !disableConntrack,
|
routingEnabled: false,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
stateful: !disableConntrack,
|
||||||
netstack: netstack.IsEnabled(),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
netstack: netstack.IsEnabled(),
|
||||||
// default true for non-netstack, for netstack only if explicitly enabled
|
// default true for non-netstack, for netstack only if explicitly enabled
|
||||||
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
|
localForwarding: !netstack.IsEnabled() || enableLocalForwarding,
|
||||||
}
|
}
|
||||||
@ -166,7 +181,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only initialize trackers if stateful mode is enabled
|
|
||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
@ -175,7 +189,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.determineRouting(iface, disableServerRoutes)
|
// netstack needs the forwarder for local traffic
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
if err := m.initForwarder(); err != nil {
|
||||||
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := m.blockInvalidRouted(iface); err != nil {
|
if err := m.blockInvalidRouted(iface); err != nil {
|
||||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||||
@ -213,9 +232,21 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) {
|
func (m *Manager) determineRouting() error {
|
||||||
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
|
var disableUspRouting, forceUserspaceRouter bool
|
||||||
forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter))
|
var err error
|
||||||
|
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||||
|
disableUspRouting, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||||
|
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
@ -223,7 +254,7 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
|
|||||||
m.nativeRouter = false
|
m.nativeRouter = false
|
||||||
log.Info("userspace routing is disabled")
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
case disableServerRoutes:
|
case m.disableServerRoutes:
|
||||||
// if server routes are disabled we will let packets pass to the native stack
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
m.routingEnabled = true
|
m.routingEnabled = true
|
||||||
m.nativeRouter = true
|
m.nativeRouter = true
|
||||||
@ -252,32 +283,37 @@ func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes
|
|||||||
log.Info("userspace routing enabled by default")
|
log.Info("userspace routing enabled by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
// netstack needs the forwarder for local traffic
|
if m.routingEnabled && !m.nativeRouter {
|
||||||
if m.netstack && m.localForwarding ||
|
return m.initForwarder()
|
||||||
m.routingEnabled && !m.nativeRouter {
|
|
||||||
|
|
||||||
m.initForwarder(iface)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initForwarder initializes the forwarder, it disables routing on errors
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
func (m *Manager) initForwarder(iface common.IFaceMapper) {
|
func (m *Manager) initForwarder() error {
|
||||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
if m.forwarder != nil {
|
||||||
intf := iface.GetWGDevice()
|
return nil
|
||||||
if intf == nil {
|
|
||||||
log.Info("forwarding not supported")
|
|
||||||
m.routingEnabled = false
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(iface, m.logger, m.netstack)
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
if err != nil {
|
intf := m.wgIface.GetWGDevice()
|
||||||
log.Errorf("failed to create forwarder: %v", err)
|
if intf == nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled = false
|
||||||
return
|
return errors.New("forwarding not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
||||||
|
if err != nil {
|
||||||
|
m.routingEnabled = false
|
||||||
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder = forwarder
|
m.forwarder = forwarder
|
||||||
|
|
||||||
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Init(*statemanager.Manager) error {
|
func (m *Manager) Init(*statemanager.Manager) error {
|
||||||
@ -285,7 +321,7 @@ func (m *Manager) Init(*statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
@ -586,7 +622,6 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if !m.isValidPacket(d, packetData) {
|
if !m.isValidPacket(d, packetData) {
|
||||||
m.logger.Trace("Invalid packet structure")
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -658,11 +693,9 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get protocol and ports for route ACL check
|
|
||||||
proto := getProtocolFromPacket(d)
|
proto := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
// Check route ACLs
|
|
||||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
||||||
srcIP, srcPort, dstIP, dstPort, proto)
|
srcIP, srcPort, dstIP, dstPort, proto)
|
||||||
@ -704,12 +737,12 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
|||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
log.Tracef("couldn't decode layer, err: %s", err)
|
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
if len(d.decoded) < 2 {
|
||||||
log.Tracef("not enough levels in network packet")
|
m.logger.Trace("packet doesn't have network and transport layers")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
@ -953,3 +986,34 @@ func (m *Manager) SetLogLevel(level log.Level) {
|
|||||||
m.logger.SetLevel(nblog.Level(level))
|
m.logger.SetLevel(nblog.Level(level))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.determineRouting()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.forwarder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routingEnabled = false
|
||||||
|
m.nativeRouter = false
|
||||||
|
|
||||||
|
// don't stop forwarder if in use by netstack
|
||||||
|
if m.netstack && m.localForwarding {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.forwarder.Stop()
|
||||||
|
m.forwarder = nil
|
||||||
|
|
||||||
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -303,6 +303,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false)
|
||||||
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled)
|
require.True(tb, manager.routingEnabled)
|
||||||
|
@ -286,15 +286,15 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||||
}
|
}
|
||||||
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
|
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter == nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.clientRoutes = newClientRoutesIDMap
|
if err := m.serverRouter.updateRoutes(newServerRoutesMap); err != nil {
|
||||||
|
return fmt.Errorf("update routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -71,9 +71,15 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(m.routes) > 0 {
|
if len(m.routes) > 0 {
|
||||||
err := systemops.EnableIPForwarding()
|
if err := systemops.EnableIPForwarding(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("enable ip forwarding: %w", err)
|
||||||
return err
|
}
|
||||||
|
if err := m.firewall.EnableRouting(); err != nil {
|
||||||
|
return fmt.Errorf("enable routing: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.firewall.DisableRouting(); err != nil {
|
||||||
|
return fmt.Errorf("disable routing: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user