Handle disable-server-routes flag in userspace router

This commit is contained in:
Viktor Liu 2025-01-09 14:08:44 +01:00
parent 28f5cd523a
commit daf935942c
9 changed files with 70 additions and 54 deletions

View File

@ -14,13 +14,13 @@ import (
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
fm, err := uspfilter.Create(iface, disableServerRoutes)
if err != nil {
return nil, err
}

View File

@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager)
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
if !iface.IsUserspaceBind() {
return fm, err
@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
return createUserspaceFirewall(iface, fm, disableServerRoutes)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
} else {
fm, errUsp = uspfilter.Create(iface)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
}
if errUsp != nil {

View File

@ -87,18 +87,23 @@ type decoder struct {
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper) (*Manager, error) {
return create(iface)
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
return create(iface, disableServerRoutes)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
mgr, err := create(iface)
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
mgr, err := create(iface, disableServerRoutes)
if err != nil {
return nil, err
}
mgr.nativeFirewall = nativeFirewall
if disableServerRoutes {
// skip native vs userspace router logic altogether
return mgr, nil
}
if forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)); forceUserspaceRouter {
log.Info("userspace routing is forced")
return mgr, nil
@ -125,7 +130,7 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}
func create(iface common.IFaceMapper) (*Manager, error) {
func create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
m := &Manager{
@ -147,6 +152,7 @@ func create(iface common.IFaceMapper) (*Manager, error) {
routeRules: make(map[string]RouteRule),
wgIface: iface,
localipmanager: newLocalIPManager(),
routingEnabled: false,
stateful: !disableConntrack,
// TODO: support changing log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()),
@ -166,23 +172,16 @@ func create(iface common.IFaceMapper) (*Manager, error) {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if disableRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)); disableRouting {
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
if disableUspRouting || disableServerRoutes {
log.Info("userspace routing is disabled")
return m, nil
}
intf := iface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
// Only supported in userspace mode as we need to inject packets back into wireguard directly
} else {
var err error
m.forwarder, err = forwarder.New(iface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
} else {
m.routingEnabled = true
}
// netstack needs the forwarder for local traffic
if m.netstack || m.routingEnabled {
m.initForwarder(iface)
}
if err := iface.SetFilter(m); err != nil {
@ -191,6 +190,25 @@ func create(iface common.IFaceMapper) (*Manager, error) {
return m, nil
}
func (m *Manager) initForwarder(iface common.IFaceMapper) {
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := iface.GetWGDevice()
if intf == nil {
log.Info("forwarding not supported")
m.routingEnabled = false
return
}
forwarder, err := forwarder.New(iface, m.logger, m.netstack)
if err != nil {
log.Errorf("failed to create forwarder: %v", err)
m.routingEnabled = false
return
}
m.forwarder = forwarder
}
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
@ -509,8 +527,6 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// TODO: Disable router if --disable-server-router is set
m.mutex.RLock()
defer m.mutex.RUnlock()

View File

@ -158,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -203,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -251,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -450,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -577,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -668,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -787,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
@ -875,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})

View File

@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false)
require.NoError(t, err)
require.NotNil(t, manager)
@ -249,7 +249,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false)
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled)

View File

@ -62,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -82,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -117,7 +117,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -210,7 +210,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@ -263,7 +263,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -307,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@ -376,7 +376,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface)
manager, err := Create(iface, false)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@ -422,7 +422,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@ -508,7 +508,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, false)
require.NoError(t, err)
time.Sleep(time.Second)
@ -539,7 +539,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
}, false)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{

View File

@ -52,7 +52,7 @@ func TestDefaultManager(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil)
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return
@ -346,7 +346,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil)
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
if err != nil {
t.Errorf("create firewall: %v", err)
return

View File

@ -849,7 +849,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface)
pf, err := uspfilter.Create(wgIface, false)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@ -464,7 +464,7 @@ func (e *Engine) createFirewall() error {
}
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil