diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index b86d16043..f093f3429 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -14,8 +14,13 @@ import ( type localIPManager struct { mu sync.RWMutex - // Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory) - ipv4Bitmap [1 << 16]uint32 + // fixed-size high array for upper byte of a IPv4 address + ipv4Bitmap [256]*ipv4LowBitmap +} + +// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address +type ipv4LowBitmap struct { + bitmap [8192]uint32 } func newLocalIPManager() *localIPManager { @@ -27,35 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) { if ipv4 == nil { return } - high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) - low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) - m.ipv4Bitmap[high] |= 1 << (low % 32) + high := uint16(ipv4[0]) + low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) + + index := low / 32 + bit := low % 32 + + if m.ipv4Bitmap[high] == nil { + m.ipv4Bitmap[high] = &ipv4LowBitmap{} + } + + m.ipv4Bitmap[high].bitmap[index] |= 1 << bit } -func (m *localIPManager) checkBitmapBit(ip []byte) bool { - high := (uint16(ip[0]) << 8) | uint16(ip[1]) - low := (uint16(ip[2]) << 8) | uint16(ip[3]) - return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 -} - -func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { +func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { if ipv4 := ip.To4(); ipv4 != nil { - high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) - low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) - if int(high) >= len(*newIPv4Bitmap) { - return fmt.Errorf("invalid IPv4 address: %s", ip) + high := uint16(ipv4[0]) + low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) + + if bitmap[high] == nil { + bitmap[high] = &ipv4LowBitmap{} } - ipStr := ip.String() + + index := low / 32 + bit := low % 32 + bitmap[high].bitmap[index] |= 1 << bit + + ipStr := ipv4.String() if _, exists := ipv4Set[ipStr]; !exists { ipv4Set[ipStr] = struct{}{} *ipv4Addresses = append(*ipv4Addresses, ipStr) - newIPv4Bitmap[high] |= 1 << (low % 32) } } +} + +func (m *localIPManager) checkBitmapBit(ip []byte) bool { + high := uint16(ip[0]) + low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3]) + + if m.ipv4Bitmap[high] == nil { + return false + } + + index := low / 32 + bit := low % 32 + return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 +} + +func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { + m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) return nil } -func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { +func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { addrs, err := iface.Addrs() if err != nil { log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) @@ -73,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 continue } - if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil { + if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil { log.Debugf("process IP failed: %v", err) } } @@ -86,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { } }() - var newIPv4Bitmap [1 << 16]uint32 + var newIPv4Bitmap [256]*ipv4LowBitmap ipv4Set := make(map[string]struct{}) var ipv4Addresses []string // 127.0.0.0/8 - high := uint16(127) << 8 - for i := uint16(0); i < 256; i++ { - newIPv4Bitmap[high|i] = 0xffffffff + newIPv4Bitmap[127] = &ipv4LowBitmap{} + for i := 0; i < 8192; i++ { + newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF } if iface != nil { @@ -120,12 +149,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { } func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { + if !ip.Is4() { + return false + } + m.mu.RLock() defer m.mu.RUnlock() - if ip.Is4() { - return m.checkBitmapBit(ip.AsSlice()) - } - - return false + return m.checkBitmapBit(ip.AsSlice()) } diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 0715ddc41..0104c9603 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -77,6 +77,18 @@ func TestLocalIPManager(t *testing.T) { testIP: netip.MustParseAddr("192.168.1.2"), expected: false, }, + { + name: "Local IP doesn't match - addresses 32 apart", + setupAddr: wgaddr.Address{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: netip.MustParseAddr("192.168.1.33"), + expected: false, + }, { name: "IPv6 address", setupAddr: wgaddr.Address{ @@ -192,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) { interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) } - // Setup bitmap version - bitmapManager := &localIPManager{ - ipv4Bitmap: [1 << 16]uint32{}, - } + // Setup bitmap + bitmapManager := newLocalIPManager() for _, ip := range interfaces[:8] { // Add half of IPs bitmapManager.setBitmapBit(ip) } @@ -248,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) { // Create two managers - one checks WG IP first, other checks it last b.Run("WG_First", func(b *testing.B) { - bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + bm := newLocalIPManager() bm.setBitmapBit(wgIP) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -257,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) { }) b.Run("WG_Last", func(b *testing.B) { - bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + bm := newLocalIPManager() // Fill with other IPs first for i := 0; i < 15; i++ { bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))