diff --git a/client/firewall/create.go b/client/firewall/create.go index 9466f4b4d..37ea5ceb3 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -14,13 +14,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface) + fm, err := uspfilter.Create(iface, disableServerRoutes) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 076d08ec2..be1b37916 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) if !iface.IsUserspaceBind() { return fm, err @@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm) + return createUserspaceFirewall(iface, fm, disableServerRoutes) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { fm, err := createFW(iface) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) @@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) { } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) } else { - fm, errUsp = uspfilter.Create(iface) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes) } if errUsp != nil { diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 0da518c9f..7459a9a49 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -87,18 +87,23 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper) (*Manager, error) { - return create(iface) +func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { + return create(iface, disableServerRoutes) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { - mgr, err := create(iface) +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { + mgr, err := create(iface, disableServerRoutes) if err != nil { return nil, err } mgr.nativeFirewall = nativeFirewall + if disableServerRoutes { + // skip native vs userspace router logic altogether + return mgr, nil + } + if forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)); forceUserspaceRouter { log.Info("userspace routing is forced") return mgr, nil @@ -125,7 +130,7 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. return mgr, nil } -func create(iface common.IFaceMapper) (*Manager, error) { +func create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) m := &Manager{ @@ -147,6 +152,7 @@ func create(iface common.IFaceMapper) (*Manager, error) { routeRules: make(map[string]RouteRule), wgIface: iface, localipmanager: newLocalIPManager(), + routingEnabled: false, stateful: !disableConntrack, // TODO: support changing log level from logrus logger: nblog.NewFromLogrus(log.StandardLogger()), @@ -166,23 +172,16 @@ func create(iface common.IFaceMapper) (*Manager, error) { m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) } - if disableRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)); disableRouting { + disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)) + if disableUspRouting || disableServerRoutes { log.Info("userspace routing is disabled") - return m, nil + } else { + m.routingEnabled = true } - intf := iface.GetWGDevice() - if intf == nil { - log.Info("forwarding not supported") - // Only supported in userspace mode as we need to inject packets back into wireguard directly - } else { - var err error - m.forwarder, err = forwarder.New(iface, m.logger, m.netstack) - if err != nil { - log.Errorf("failed to create forwarder: %v", err) - } else { - m.routingEnabled = true - } + // netstack needs the forwarder for local traffic + if m.netstack || m.routingEnabled { + m.initForwarder(iface) } if err := iface.SetFilter(m); err != nil { @@ -191,6 +190,25 @@ func create(iface common.IFaceMapper) (*Manager, error) { return m, nil } +func (m *Manager) initForwarder(iface common.IFaceMapper) { + // Only supported in userspace mode as we need to inject packets back into wireguard directly + intf := iface.GetWGDevice() + if intf == nil { + log.Info("forwarding not supported") + m.routingEnabled = false + return + } + + forwarder, err := forwarder.New(iface, m.logger, m.netstack) + if err != nil { + log.Errorf("failed to create forwarder: %v", err) + m.routingEnabled = false + return + } + + m.forwarder = forwarder +} + func (m *Manager) Init(*statemanager.Manager) error { return nil } @@ -509,8 +527,6 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { // dropFilter implements filtering logic for incoming packets. // If it returns true, the packet should be dropped. func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { - // TODO: Disable router if --disable-server-router is set - m.mutex.RLock() defer m.mutex.RUnlock() diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index c3974b492..827cc1f0c 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 93b947148..73209a152 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, false) require.NoError(t, err) require.NotNil(t, manager) @@ -249,7 +249,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, false) require.NoError(tb, err) require.NotNil(tb, manager) require.True(tb, manager.routingEnabled) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 95f79115a..e13fe8062 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -62,7 +62,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -82,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -117,7 +117,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -210,7 +210,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -263,7 +263,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -307,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -376,7 +376,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface) + manager, err := Create(iface, false) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -422,7 +422,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -508,7 +508,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, false) require.NoError(t, err) time.Sleep(time.Second) @@ -539,7 +539,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index d66414af9..d146fef1f 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -52,7 +52,7 @@ func TestDefaultManager(t *testing.T) { ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil, false) if err != nil { t.Errorf("create firewall: %v", err) return @@ -346,7 +346,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil, false) if err != nil { t.Errorf("create firewall: %v", err) return diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index c166820c4..14ff1bb71 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -849,7 +849,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface) + pf, err := uspfilter.Create(wgIface, false) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index 8d443c8aa..069c748a5 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -464,7 +464,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil