mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 10:18:50 +02:00
Handle all local IPs
This commit is contained in:
parent
ed22d79f04
commit
a12a9ac290
128
client/firewall/uspfilter/localip.go
Normal file
128
client/firewall/uspfilter/localip.go
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type localIPManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// Use bitmap for IPv4 (32 bits * 2^16 = 8KB memory)
|
||||||
|
ipv4Bitmap [1 << 16]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLocalIPManager() *localIPManager {
|
||||||
|
return &localIPManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
if int(high) >= len(*newIPv4Bitmap) {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
ipStr := ip.String()
|
||||||
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
|
ipv4Set[ipStr] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
|
(*newIPv4Bitmap)[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
|
log.Debugf("process IP failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get interfaces: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newIPv4Bitmap [1 << 16]uint32
|
||||||
|
ipv4Set := make(map[string]struct{})
|
||||||
|
var ipv4Addresses []string
|
||||||
|
|
||||||
|
if iface != nil {
|
||||||
|
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.ipv4Bitmap = newIPv4Bitmap
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
return m.checkBitmapBit(ipv4)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
93
client/firewall/uspfilter/localip_test.go
Normal file
93
client/firewall/uspfilter/localip_test.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MapImplementation is a version using map[string]struct{}
|
||||||
|
type MapImplementation struct {
|
||||||
|
localIPs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIPChecks(b *testing.B) {
|
||||||
|
interfaces := make([]net.IP, 16)
|
||||||
|
for i := range interfaces {
|
||||||
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup bitmap version
|
||||||
|
bitmapManager := &localIPManager{
|
||||||
|
ipv4Bitmap: [1 << 16]uint32{},
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
|
bitmapManager.setBitmapBit(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup map version
|
||||||
|
mapManager := &MapImplementation{
|
||||||
|
localIPs: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] {
|
||||||
|
mapManager.localIPs[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWGPosition(b *testing.B) {
|
||||||
|
wgIP := net.ParseIP("10.10.0.1")
|
||||||
|
|
||||||
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
bm.setBitmapBit(wgIP)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
// Fill with other IPs first
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
}
|
||||||
|
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@ -55,8 +55,11 @@ type Manager struct {
|
|||||||
routingEnabled bool
|
routingEnabled bool
|
||||||
// indicates whether we leave forwarding and filtering to the native firewall
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
nativeRouter bool
|
nativeRouter bool
|
||||||
|
// indicates whether we track outbound connections
|
||||||
|
stateful bool
|
||||||
|
|
||||||
|
localipmanager *localIPManager
|
||||||
|
|
||||||
stateful bool
|
|
||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
@ -120,15 +123,20 @@ func create(iface common.IFaceMapper) (*Manager, error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[string]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[string]RuleSet),
|
||||||
routeRules: make(map[string]RouteRule),
|
routeRules: make(map[string]RouteRule),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
stateful: !disableConntrack,
|
localipmanager: newLocalIPManager(),
|
||||||
|
stateful: !disableConntrack,
|
||||||
// TODO: support changing log level from logrus
|
// TODO: support changing log level from logrus
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Only initialize trackers if stateful mode is enabled
|
// Only initialize trackers if stateful mode is enabled
|
||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
@ -346,9 +354,9 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
|||||||
return m.dropFilter(packetData, m.incomingRules)
|
return m.dropFilter(packetData, m.incomingRules)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isLocalIP(ip net.IP) bool {
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
// TODO: add other interface IPs and keep track of them
|
func (m *Manager) UpdateLocalIPs() error {
|
||||||
return ip.Equal(m.wgIface.Address().IP)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
||||||
@ -496,7 +504,7 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle local traffic - apply peer ACLs
|
// Handle local traffic - apply peer ACLs
|
||||||
if m.isLocalIP(dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
drop := m.applyRules(srcIP, packetData, rules, d)
|
drop := m.applyRules(srcIP, packetData, rules, d)
|
||||||
if drop {
|
if drop {
|
||||||
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
|
m.logger.Trace("Dropping local packet: src=%s dst=%s rules=denied",
|
||||||
|
@ -40,13 +40,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
@ -186,6 +186,10 @@ type Peer struct {
|
|||||||
WgAllowedIps string
|
WgAllowedIps string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type localIpUpdater interface {
|
||||||
|
UpdateLocalIPs() error
|
||||||
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine
|
// NewEngine creates a new Connection Engine
|
||||||
func NewEngine(
|
func NewEngine(
|
||||||
clientCtx context.Context,
|
clientCtx context.Context,
|
||||||
@ -802,6 +806,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.acl.ApplyFiltering(networkMap)
|
e.acl.ApplyFiltering(networkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.firewall != nil {
|
||||||
|
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||||
|
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||||
|
log.Errorf("failed to update local IPs: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DNS forwarder
|
// DNS forwarder
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user