diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go new file mode 100644 index 000000000..9636e1cc9 --- /dev/null +++ b/client/firewall/uspfilter/localip.go @@ -0,0 +1,128 @@ +package uspfilter + +import ( + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" +) + +type localIPManager struct { + mu sync.RWMutex + + // Use bitmap for IPv4 (32 bits * 2^16 = 8KB memory) + ipv4Bitmap [1 << 16]uint32 +} + +func newLocalIPManager() *localIPManager { + return &localIPManager{} +} + +func (m *localIPManager) setBitmapBit(ip net.IP) { + ipv4 := ip.To4() + if ipv4 == nil { + return + } + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + m.ipv4Bitmap[high] |= 1 << (low % 32) +} + +func (m *localIPManager) checkBitmapBit(ip net.IP) bool { + ipv4 := ip.To4() + if ipv4 == nil { + return false + } + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 +} + +func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { + if ipv4 := ip.To4(); ipv4 != nil { + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + if int(high) >= len(*newIPv4Bitmap) { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + ipStr := ip.String() + if _, exists := ipv4Set[ipStr]; !exists { + ipv4Set[ipStr] = struct{}{} + *ipv4Addresses = append(*ipv4Addresses, ipStr) + (*newIPv4Bitmap)[high] |= 1 << (low % 32) + } + } + return nil +} + +func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { + addrs, err := iface.Addrs() + if err != nil { + log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) + return + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + + if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil { + log.Debugf("process IP failed: %v", err) + } + } +} + +func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic: %v", r) + } + }() + + interfaces, err := net.Interfaces() + if err != nil { + return fmt.Errorf("get interfaces: %w", err) + } + + var newIPv4Bitmap [1 << 16]uint32 + ipv4Set := make(map[string]struct{}) + var ipv4Addresses []string + + if iface != nil { + if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { + return err + } + } + + for _, intf := range interfaces { + m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) + } + + m.mu.Lock() + m.ipv4Bitmap = newIPv4Bitmap + m.mu.Unlock() + + log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) + return nil +} + +func (m *localIPManager) IsLocalIP(ip net.IP) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + if ipv4 := ip.To4(); ipv4 != nil { + return m.checkBitmapBit(ipv4) + } + + return false +} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go new file mode 100644 index 000000000..f179fb2e3 --- /dev/null +++ b/client/firewall/uspfilter/localip_test.go @@ -0,0 +1,93 @@ +package uspfilter + +import ( + "net" + "testing" +) + +// MapImplementation is a version using map[string]struct{} +type MapImplementation struct { + localIPs map[string]struct{} +} + +func BenchmarkIPChecks(b *testing.B) { + interfaces := make([]net.IP, 16) + for i := range interfaces { + interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) + } + + // Setup bitmap version + bitmapManager := &localIPManager{ + ipv4Bitmap: [1 << 16]uint32{}, + } + for _, ip := range interfaces[:8] { // Add half of IPs + bitmapManager.setBitmapBit(ip) + } + + // Setup map version + mapManager := &MapImplementation{ + localIPs: make(map[string]struct{}), + } + for _, ip := range interfaces[:8] { + mapManager.localIPs[ip.String()] = struct{}{} + } + + b.Run("Bitmap_Hit", func(b *testing.B) { + ip := interfaces[4] + b.ResetTimer() + for i := 0; i < b.N; i++ { + bitmapManager.checkBitmapBit(ip) + } + }) + + b.Run("Bitmap_Miss", func(b *testing.B) { + ip := interfaces[12] + b.ResetTimer() + for i := 0; i < b.N; i++ { + bitmapManager.checkBitmapBit(ip) + } + }) + + b.Run("Map_Hit", func(b *testing.B) { + ip := interfaces[4] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = mapManager.localIPs[ip.String()] + } + }) + + b.Run("Map_Miss", func(b *testing.B) { + ip := interfaces[12] + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = mapManager.localIPs[ip.String()] + } + }) +} + +func BenchmarkWGPosition(b *testing.B) { + wgIP := net.ParseIP("10.10.0.1") + + // Create two managers - one checks WG IP first, other checks it last + b.Run("WG_First", func(b *testing.B) { + bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + bm.setBitmapBit(wgIP) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bm.checkBitmapBit(wgIP) + } + }) + + b.Run("WG_Last", func(b *testing.B) { + bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + // Fill with other IPs first + for i := 0; i < 15; i++ { + bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) + } + bm.setBitmapBit(wgIP) // Add WG IP last + b.ResetTimer() + for i := 0; i < b.N; i++ { + bm.checkBitmapBit(wgIP) + } + }) +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index ba7b2e8b5..99035c4a3 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -55,8 +55,11 @@ type Manager struct { routingEnabled bool // indicates whether we leave forwarding and filtering to the native firewall nativeRouter bool + // indicates whether we track outbound connections + stateful bool + + localipmanager *localIPManager - stateful bool udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker @@ -120,15 +123,20 @@ func create(iface common.IFaceMapper) (*Manager, error) { return d }, }, - outgoingRules: make(map[string]RuleSet), - incomingRules: make(map[string]RuleSet), - routeRules: make(map[string]RouteRule), - wgIface: iface, - stateful: !disableConntrack, + outgoingRules: make(map[string]RuleSet), + incomingRules: make(map[string]RuleSet), + routeRules: make(map[string]RouteRule), + wgIface: iface, + localipmanager: newLocalIPManager(), + stateful: !disableConntrack, // TODO: support changing log level from logrus logger: nblog.NewFromLogrus(log.StandardLogger()), } + if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { + return nil, fmt.Errorf("update local IPs: %w", err) + } + // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") @@ -346,9 +354,9 @@ func (m *Manager) DropIncoming(packetData []byte) bool { return m.dropFilter(packetData, m.incomingRules) } -func (m *Manager) isLocalIP(ip net.IP) bool { - // TODO: add other interface IPs and keep track of them - return ip.Equal(m.wgIface.Address().IP) +// UpdateLocalIPs updates the list of local IPs +func (m *Manager) UpdateLocalIPs() error { + return m.localipmanager.UpdateLocalIPs(m.wgIface) } func (m *Manager) processOutgoingHooks(packetData []byte) bool { @@ -496,7 +504,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { } // Handle local traffic - apply peer ACLs - if m.isLocalIP(dstIP) { + if m.localipmanager.IsLocalIP(dstIP) { drop := m.applyRules(srcIP, packetData, rules, d) if drop { m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied", diff --git a/client/internal/engine.go b/client/internal/engine.go index 042d384dc..584d41195 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -40,13 +40,13 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" mgm "github.com/netbirdio/netbird/management/client" - "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" auth "github.com/netbirdio/netbird/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/relay/client" @@ -186,6 +186,10 @@ type Peer struct { WgAllowedIps string } +type localIpUpdater interface { + UpdateLocalIPs() error +} + // NewEngine creates a new Connection Engine func NewEngine( clientCtx context.Context, @@ -802,6 +806,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } + if e.firewall != nil { + if localipfw, ok := e.firewall.(localIpUpdater); ok { + if err := localipfw.UpdateLocalIPs(); err != nil { + log.Errorf("failed to update local IPs: %v", err) + } + } + } + // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())