[client] Use platform-native routing APIs for freeBSD, macOS and Windows

This commit is contained in:
Viktor Liu 2025-06-04 16:28:58 +02:00 committed by GitHub
parent 87148c503f
commit ea4d13e96d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
53 changed files with 1552 additions and 1046 deletions

View File

@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,

View File

@ -2,7 +2,7 @@ package iptables
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err)
ip := net.ParseIP("10.20.0.100")
ip := netip.MustParseAddr("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}

View File

@ -3,7 +3,6 @@ package nftables
import (
"bytes"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
}
},
}
@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
ip := net.ParseIP("100.96.0.1")
ip := netip.MustParseAddr("100.96.0.1").Unmap()
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
}
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
Data: ip.AsSlice(),
},
&expr.Payload{
DestRegister: 1,
@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
}
},
}
@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
}()
ip := net.ParseIP("10.20.0.100")
ip := netip.MustParseAddr("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(

View File

@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
ip tcpip.Address
netstack bool
}
@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: iface.Address().Network.Bits(),
},
}
@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
}
receiveWindow := defaultReceiveWindow
@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil {
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
}
@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue
}
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}()
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []netip.Addr
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost range",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
IP: netip.MustParseAddr("fe80::1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("fe80::1"),
expected: false,

View File

@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}

View File

@ -71,7 +71,6 @@ type Manager struct {
// incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@ -1091,11 +1090,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true
}
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not

View File

@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup
sc.setupFunc(manager)
@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table
srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count)
@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP,
state: "post_handshake",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
}
for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}

View File

@ -19,12 +19,8 @@ import (
)
func TestPeerACLFiltering(t *testing.T) {
localIP := net.ParseIP("100.10.0.100")
wgNet := &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
localIP := netip.MustParseAddr("100.10.0.100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil))
})
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs()
require.NoError(t, err)
@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network)
require.NoError(tb, err)
wgNet := netip.MustParsePrefix(network)
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
IP: wgNet.Addr(),
Network: wgNet,
}
},
@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}

View File

@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err)
return
}
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() {
@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{

View File

@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a.AsSlice()) {
if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}

View File

@ -1,7 +1,6 @@
package device
import (
"net"
"net/netip"
"sync"
@ -24,9 +23,6 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
}
// FilteredDevice to override Read or Write of packets

View File

@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
}
ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter)
return nil

View File

@ -5,7 +5,6 @@
package mocks
import (
net "net"
"net/netip"
reflect "reflect"
@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@ -1,8 +1,6 @@
package netstack
import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
@ -15,8 +13,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive
address net.IP
dnsAddress net.IP
address netip.Addr
dnsAddress netip.Addr
mtu int
listenAddress string
@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device
}
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{
address: address,
dnsAddress: dnsAddress,
@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
}
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{addr.Unmap()},
[]netip.Addr{dnsAddr.Unmap()},
[]netip.Addr{t.address},
[]netip.Addr{t.dnsAddress},
t.mtu)
if err != nil {
return nil, nil, err

View File

@ -2,28 +2,27 @@ package wgaddr
import (
"fmt"
"net"
"net/netip"
)
// Address WireGuard parsed address
type Address struct {
IP net.IP
Network *net.IPNet
IP netip.Addr
Network netip.Prefix
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address)
prefix, err := netip.ParsePrefix(address)
if err != nil {
return Address{}, err
}
return Address{
IP: ip,
Network: network,
IP: prefix.Addr().Unmap(),
Network: prefix.Masked(),
}, nil
}
func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
}

View File

@ -1,7 +1,7 @@
package acl
import (
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
@ -43,12 +43,11 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
require.NoError(t, err)
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip,
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
@ -162,12 +161,11 @@ func TestDefaultManagerStateless(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
require.NoError(t, err)
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip,
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
@ -372,12 +370,11 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32")
require.NoError(t, err)
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip,
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()

View File

@ -2,7 +2,7 @@ package internal
import (
"fmt"
"net"
"net/netip"
"slices"
"strings"
@ -12,13 +12,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData)
if ip == nil || ip.To4() == nil {
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip, err := netip.ParseAddr(aRecord.RData)
if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false
}
if !ipNet.Contains(ip) {
if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false
}
@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
}
// generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask)
maskOnes, _ := ipNet.Mask.Size()
func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := network.Masked().Addr()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte
octetsToUse := (maskOnes + 7) / 8
octetsToUse := (network.Bits() + 7) / 8
octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
}
reverseOctets := make([]string, octetsToUse)
@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
}
// collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
continue
}
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
records = append(records, ptrRecord)
}
}
@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
}
// addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
zoneName, err := generateReverseZoneName(ipNet)
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(network)
if err != nil {
log.Warn(err)
return
@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
return
}
records := collectPTRRecords(config, ipNet)
records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{
Domain: zoneName,

View File

@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
}
func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{
IP: ip,
Network: network,
IP: netip.MustParseAddr("100.66.100.1"),
Network: netip.MustParsePrefix("100.66.100.0/24"),
}
}
@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)

View File

@ -24,11 +24,15 @@ type ServiceViaMemory struct {
}
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
runtimeIP: lastIP.String(),
runtimePort: defaultPort,
}
return s
@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6
}

View File

@ -1,33 +0,0 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@ -3,6 +3,7 @@ package dns
import (
"context"
"net"
"net/netip"
"syscall"
"time"
@ -23,8 +24,8 @@ type upstreamResolver struct {
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
_ netip.Addr,
_ netip.Prefix,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

View File

@ -4,7 +4,7 @@ package dns
import (
"context"
"net"
"net/netip"
"time"
"github.com/miekg/dns"
@ -19,8 +19,8 @@ type upstreamResolver struct {
func newUpstreamResolver(
ctx context.Context,
_ string,
_ net.IP,
_ *net.IPNet,
_ netip.Addr,
_ netip.Prefix,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,

View File

@ -6,6 +6,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"syscall"
"time"
@ -18,16 +19,16 @@ import (
type upstreamResolverIOS struct {
*upstreamResolverBase
lIP net.IP
lNet *net.IPNet
lIP netip.Addr
lNet netip.Prefix
interfaceName string
}
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip net.IP,
net *net.IPNet,
ip netip.Addr,
net netip.Prefix,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
client.DialTimeout = timeout
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: ip,
IP: ip.AsSlice(),
Port: 0, // Let the OS pick a free port
},
Timeout: dialTimeout,

View File

@ -2,7 +2,7 @@ package dns
import (
"context"
"net"
"net/netip"
"strings"
"testing"
"time"
@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {

View File

@ -1008,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
log.Errorf("failed to update routes: %v", err)
}
if e.acl != nil {
@ -1104,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Network: prefix,
Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType),
@ -1138,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
return entries
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
@ -1790,9 +1790,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
}
// GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() net.IP {
func (e *Engine) GetWgAddr() netip.Addr {
if e.wgInterface == nil {
return nil
return netip.Addr{}
}
return e.wgInterface.Address().IP
}
@ -1861,12 +1861,7 @@ func (e *Engine) Address() (netip.Addr, error) {
return netip.Addr{}, errors.New("wireguard interface not initialized")
}
addr := e.wgInterface.Address()
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
}
return ip.Unmap(), nil
return e.wgInterface.Address().IP, nil
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {

View File

@ -371,11 +371,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {

View File

@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
// fallback if mark rules are not in place
wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
}
// mapRxPackets maps packet counts to RX based on flow direction
@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
// fallback if marks are not set
wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch {
case wgaddr.Equal(src):
case wgaddr == srcIP:
return nftypes.Egress
case wgaddr.Equal(dst):
case wgaddr == dstIP:
return nftypes.Ingress
case wgnetwork.Contains(src):
case wgnetwork.Contains(srcIP):
// netbird network -> resource network
return nftypes.Ingress
case wgnetwork.Contains(dst):
case wgnetwork.Contains(dstIP):
// resource network -> netbird network
return nftypes.Egress
}

View File

@ -2,7 +2,7 @@ package logger
import (
"context"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
@ -23,17 +23,16 @@ type Logger struct {
rcvChan atomic.Pointer[rcvChan]
cancel context.CancelFunc
statusRecorder *peer.Status
wgIfaceIPNet net.IPNet
wgIfaceNet netip.Prefix
dnsCollection atomic.Bool
exitNodeCollection atomic.Bool
Store types.Store
}
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
return &Logger{
statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet,
wgIfaceNet: wgIfaceIPNet,
Store: store.NewMemoryStore(),
}
}
@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool
var isDestExitNode bool
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
if !l.wgIfaceNet.Contains(event.SourceIP) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
}
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
if !l.wgIfaceNet.Contains(event.DestIP) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
}

View File

@ -1,7 +1,7 @@
package logger_test
import (
"net"
"net/netip"
"testing"
"time"
@ -12,7 +12,7 @@ import (
)
func TestStore(t *testing.T) {
logger := logger.New(nil, net.IPNet{})
logger := logger.New(nil, netip.Prefix{})
logger.Enable()
event := types.EventFields{

View File

@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
@ -34,11 +34,11 @@ type Manager struct {
// NewManager creates a new netflow manager
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet
var prefix netip.Prefix
if iface != nil {
ipNet = *iface.Address().Network
prefix = iface.Address().Network
}
flowLogger := logger.New(statusRecorder, ipNet)
flowLogger := logger.New(statusRecorder, prefix)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {

View File

@ -1,7 +1,7 @@
package netflow
import (
"net"
"net/netip"
"testing"
"time"
@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
Network: netip.MustParsePrefix("192.168.1.1/32"),
},
isUserspaceBind: true,
}
@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
Network: netip.MustParsePrefix("192.168.1.1/32"),
},
isUserspaceBind: true,
}

View File

@ -264,7 +264,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
continue
}
prefix := netip.PrefixFrom(ip, ip.BitLen())
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
newPrefixes = append(newPrefixes, prefix)
}

View File

@ -333,11 +333,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
var merr *multierror.Error
if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
log.Errorf("Failed to update system routes: %v", err)
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
}
m.updateClientNetworks(updateSerial, filteredClientRoutes)
@ -346,14 +347,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.clientRoutes = newClientRoutesIDMap
if m.serverRouter == nil {
return nil
return nberrors.FormatErrorOrNil(merr)
}
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err)
merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
// SetRouteChangeListener set RouteListener for route change Notifier

View File

@ -44,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -71,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.252.250/30"),
Network: netip.MustParsePrefix("100.64.252.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -99,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
Network: netip.MustParsePrefix("100.64.30.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -127,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
Network: netip.MustParsePrefix("100.64.30.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -211,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -233,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -250,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -272,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -282,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "b",
NetID: "routeA",
Peer: remotePeerKey2,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -299,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -327,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -356,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "l1",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -376,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "r1",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
@ -440,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
require.NoError(t, err, "should update routes with init routes")
}
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const (
@ -22,8 +22,13 @@ const (
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
type iface interface {
Address() wgaddr.Address
Name() string
}
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface iface.WGIface) (map[string]int, error) {
func Setup(wgIface iface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error

View File

@ -6,9 +6,10 @@ import (
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type Nexthop struct {
@ -30,11 +31,16 @@ func (n Nexthop) String() string {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
type wgIface interface {
Address() wgaddr.Address
Name() string
}
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface iface.WGIface
wgInterface wgIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
@ -45,9 +51,27 @@ type SysOps struct {
notifier *notifier.Notifier
}
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()
switch {
case
!addr.IsValid(),
addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsMulticast(),
addr.IsUnspecified() && prefix.Bits() != 0,
r.wgInterface.Address().Network.Contains(addr):
return vars.ErrRouteNotAllowed
}
return nil
}

View File

@ -8,6 +8,8 @@ import (
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
@ -33,7 +35,12 @@ func init() {
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil)
@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return "lo0"
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@ -17,7 +17,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
if err := r.validateRoute(prefix); err != nil {
return Nexthop{}, err
}
addr := prefix.Addr()
if addr.IsUnspecified() {
return Nexthop{}, vars.ErrRouteNotAllowed
}
@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
return r.addToRouteTable(prefix, nextHop)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
return Nexthop{Intf: intf}, nil
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()

View File

@ -3,23 +3,25 @@
package systemops
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"testing"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type dialer interface {
@ -27,105 +29,370 @@ type dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddRemoveRoutes(t *testing.T) {
func TestAddVPNRoute(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
shouldRouteToWireguard bool
shouldBeRemoved bool
name string
prefix netip.Prefix
expectError bool
}{
{
name: "Should Add And Remove Route 100.66.120.0/24",
prefix: netip.MustParsePrefix("100.66.120.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
name: "IPv4 - Private network route",
prefix: netip.MustParsePrefix("10.10.100.0/24"),
},
{
name: "Should Not Add Or Remove Route 127.0.0.1/32",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
shouldRouteToWireguard: false,
shouldBeRemoved: false,
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("10.111.111.111/32"),
},
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Default route",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
},
{
name: "IPv6 Default route",
prefix: netip.MustParsePrefix("::/0"),
},
// IPv4 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local multicast",
prefix: netip.MustParsePrefix("224.0.0.251/32"),
expectError: true,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
},
{
name: "IPv4 Unspecified with prefix",
prefix: netip.MustParsePrefix("0.0.0.0/32"),
expectError: true,
},
// IPv6 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local multicast",
prefix: netip.MustParsePrefix("ff02::1/128"),
expectError: true,
},
{
name: "IPv6 Interface-local multicast",
prefix: netip.MustParsePrefix("ff01::1/128"),
expectError: true,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
},
{
name: "IPv6 Unspecified with prefix",
prefix: netip.MustParsePrefix("::/128"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network subnet",
prefix: netip.MustParsePrefix("100.65.75.0/32"),
expectError: true,
},
}
for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
intf, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err)
// add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.expectError {
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
// validate it's pointing to the WireGuard interface
require.NoError(t, err)
nextHop := getNextHop(t, testCase.prefix.Addr())
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
// remove route again
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err)
// validate it's gone
nextHop, err = GetNextHop(testCase.prefix.Addr())
require.True(t,
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
"err: %v, next hop: %v", err, nextHop)
})
}
}
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
t.Helper()
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
nextHop, err := GetNextHop(addr)
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
// present in the route table.
t.Skip("Skipping windows test")
}
require.NoError(t, err)
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
return nextHop
}
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
if addr.IsUnspecified() {
// On macOS, querying 0.0.0.0 returns the wrong interface
if addr.Is4() {
addr = netip.MustParseAddr("1.2.3.4")
} else {
addr = netip.MustParseAddr("2001:db8::1")
}
}
cmd := exec.Command("route", "-n", "get", addr.String())
if addr.Is6() {
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
}
output, err := cmd.CombinedOutput()
t.Logf("route output: %s", output)
require.NoError(t, err, "%s failed")
lines := strings.Split(string(output), "\n")
var intf string
var gateway string
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "interface:") {
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
} else if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
}
require.NotEmpty(t, intf, "interface should be found in route output")
iface, err := net.InterfaceByName(intf)
require.NoError(t, err, "interface %s should exist", intf)
nexthop := Nexthop{Intf: iface}
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
addr, err := netip.ParseAddr(gateway)
if err == nil {
nexthop.IP = addr
}
}
return nexthop
}
func TestAddRouteToNonVPNIntf(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
expectError bool
errorType error
}{
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("8.8.8.8/32"),
},
{
name: "IPv6 External network route",
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2001:db8::1/128"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
},
// Addresses that should be rejected
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Unspecified",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Unspecified",
prefix: netip.MustParsePrefix("::/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Fatalf("Failed to get IPv6 default route: %v", err)
}
var initialNextHop Nexthop
if testCase.prefix.Addr().Is6() {
initialNextHop = initialNextHopV6
} else {
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
initialNextHop = initialNextHopV4
}
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
if testCase.expectError {
require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
require.NoError(t, err)
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
// Verify the route was added and points to non-VPN interface
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err)
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
err = r.removeFromRouteTable(testCase.prefix, nexthop)
assert.NoError(t, err)
})
}
}
func TestGetNextHop(t *testing.T) {
if runtime.GOOS == "freebsd" {
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !nexthop.IP.IsValid() {
if !defaultNh.IP.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := GetNextHop(testingPrefix.Addr())
nh, err := GetNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if !localIP.IP.IsValid() {
if nh.Intf == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local IP should not match with gateway IP")
if nh.IP.String() == defaultNh.IP.String() {
t.Fatal("next hop IP should not match with default gateway IP")
}
if localIP.IP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPort: 33100,
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
switch {
case p.Addr().Is6():
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
if nh.Intf.Name != defaultNh.Intf.Name {
t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
}
}
@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
if err := r.AddVPNRoute(prefix, intf); err != nil {
if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("addVPNRoute should not return err: %v", err)
}
t.Logf("addVPNRoute %v returned: %v", prefix, err)
}
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("removeVPNRoute should not return err: %v", err)
}
})
}
@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
// 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string

View File

@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf)
}
@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf)
}
@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
prefix := netip.PrefixFrom(addr.Unmap(), ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add route: %w", err)
}
@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}

View File

@ -19,7 +19,6 @@ import (
)
var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
@ -31,12 +30,6 @@ func init() {
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
},
{
name: "To more specific route (local) without custom dialer via physical interface",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...)
}

View File

@ -11,10 +11,16 @@ import (
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericRemoveVPNRoute(prefix, intf)
}

View File

@ -0,0 +1,268 @@
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type mockWGIface struct {
address wgaddr.Address
name string
}
func (m *mockWGIface) Address() wgaddr.Address {
return m.address
}
func (m *mockWGIface) Name() string {
return m.name
}
func TestSysOps_validateRoute(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
}{
// Valid routes
{
name: "valid IPv4 route",
prefix: "192.168.1.0/24",
expectError: false,
},
{
name: "valid IPv6 route",
prefix: "2001:db8::/32",
expectError: false,
},
{
name: "valid single IPv4 host",
prefix: "8.8.8.8/32",
expectError: false,
},
{
name: "valid single IPv6 host",
prefix: "2001:4860:4860::8888/128",
expectError: false,
},
// Invalid routes - loopback
{
name: "IPv4 loopback",
prefix: "127.0.0.1/32",
expectError: true,
},
{
name: "IPv6 loopback",
prefix: "::1/128",
expectError: true,
},
// Invalid routes - link-local unicast
{
name: "IPv4 link-local unicast",
prefix: "169.254.1.1/32",
expectError: true,
},
{
name: "IPv6 link-local unicast",
prefix: "fe80::1/128",
expectError: true,
},
// Invalid routes - multicast
{
name: "IPv4 multicast",
prefix: "224.0.0.1/32",
expectError: true,
},
{
name: "IPv6 multicast",
prefix: "ff02::1/128",
expectError: true,
},
// Invalid routes - link-local multicast
{
name: "IPv4 link-local multicast",
prefix: "224.0.0.0/24",
expectError: true,
},
{
name: "IPv6 link-local multicast",
prefix: "ff02::/16",
expectError: true,
},
// Invalid routes - interface-local multicast (IPv6 only)
{
name: "IPv6 interface-local multicast",
prefix: "ff01::1/128",
expectError: true,
},
// Invalid routes - overlaps with WG interface network
{
name: "overlaps with WG network - exact match",
prefix: "10.0.0.0/24",
expectError: true,
},
{
name: "overlaps with WG network - subset",
prefix: "10.0.0.1/32",
expectError: true,
},
{
name: "overlaps with WG network - host in range",
prefix: "10.0.0.100/32",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
}
})
}
}
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
description string
}{
{
name: "identical subnet",
prefix: "192.168.100.0/24",
expectError: true,
description: "exact same network as WG interface",
},
{
name: "broader subnet containing WG network",
prefix: "192.168.0.0/16",
expectError: false,
description: "broader network that contains WG network should be allowed",
},
{
name: "host within WG network",
prefix: "192.168.100.50/32",
expectError: true,
description: "specific host within WG network",
},
{
name: "subnet within WG network",
prefix: "192.168.100.128/25",
expectError: true,
description: "smaller subnet within WG network",
},
{
name: "adjacent subnet - same /23",
prefix: "192.168.101.0/24",
expectError: false,
description: "adjacent subnet, no overlap",
},
{
name: "adjacent subnet - different /16",
prefix: "192.167.100.0/24",
expectError: false,
description: "different network, no overlap",
},
{
name: "WG network broadcast address",
prefix: "192.168.100.255/32",
expectError: true,
description: "broadcast address of WG network",
},
{
name: "WG network first usable",
prefix: "192.168.100.1/32",
expectError: true,
description: "first usable address in WG network",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
}
})
}
}
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wt0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
var invalidPrefix netip.Prefix
err := sysOps.validateRoute(invalidPrefix)
require.Error(t, err, "validateRoute() expected error for invalid prefix")
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
}

View File

@ -3,15 +3,19 @@
package systemops
import (
"errors"
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"strconv"
"syscall"
"time"
"unsafe"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop)
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop)
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
}
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet"
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
if !prefix.IsValid() {
return fmt.Errorf("invalid prefix: %s", prefix)
}
expBackOff := backoff.NewExponentialBackOff()
@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
}
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
}
return nil
}
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
if err != nil {
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
}
if prefix.IsSingleIP() {
msg.Flags |= unix.RTF_HOST
} else {
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
if err != nil {
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
}
}
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
}
} else if nexthop.Intf != nil {
msg.Index = nexthop.Intf.Index
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {
return &route.Inet4Addr{IP: addr.As4()}, nil
}
if addr.Zone() == "" {
return &route.Inet6Addr{IP: addr.As16()}, nil
}
var zone int
// zone can be either a numeric zone ID or an interface name.
if z, err := strconv.Atoi(addr.Zone()); err == nil {
zone = z
} else {
iface, err := net.InterfaceByName(addr.Zone())
if err != nil {
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
}
zone = iface.Index
}
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
}
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
m := net.CIDRMask(bits, 32)
var maskBytes [4]byte
copy(maskBytes[:], m)
return &route.Inet4Addr{IP: maskBytes}, nil
}
if prefix.Addr().Is6() {
m := net.CIDRMask(bits, 128)
var maskBytes [16]byte
copy(maskBytes[:], m)
return &route.Inet6Addr{IP: maskBytes}, nil
}
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}

View File

@ -1,5 +1,3 @@
//go:build windows
package systemops
import (
@ -9,9 +7,8 @@ import (
"net"
"net/netip"
"os"
"os/exec"
"runtime/debug"
"strconv"
"strings"
"sync"
"syscall"
"time"
@ -21,11 +18,12 @@ import (
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const InfiniteLifetime = 0xffffffff
type RouteUpdateType int
// RouteUpdate represents a change in the routing table.
@ -58,9 +56,13 @@ type MSFT_NetRoute struct {
AddressFamily uint16
}
// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
// luid represents a locally unique identifier for network interfaces
type luid uint64
// MIB_IPFORWARD_ROW2 represents a route entry in the routing table.
// It is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
type MIB_IPFORWARD_ROW2 struct {
InterfaceLuid uint64
InterfaceLuid luid
InterfaceIndex uint32
DestinationPrefix IP_ADDRESS_PREFIX
NextHop SOCKADDR_INET_NEXTHOP
@ -108,9 +110,14 @@ type SOCKADDR_INET_NEXTHOP struct {
type MIB_NOTIFICATION_TYPE int32
var (
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
prefixList []netip.Prefix
lastUpdate time.Time
@ -139,6 +146,8 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
log.Debugf("Adding route to %s via %s", prefix, nexthop)
// if we don't have an interface but a zone, extract the interface index from the zone
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
@ -147,23 +156,187 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
nexthop.Intf = &net.Interface{Index: zone}
}
return addRouteCmd(prefix, nexthop)
return addRoute(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
log.Debugf("Removing route to %s via %s", prefix, nexthop)
return deleteRoute(prefix, nexthop)
}
// setupRouteEntry prepares a route entry with common configuration
func setupRouteEntry(prefix netip.Prefix, nexthop Nexthop) (*MIB_IPFORWARD_ROW2, error) {
route := &MIB_IPFORWARD_ROW2{}
initializeIPForwardEntry(route)
// Convert interface index to luid if interface is specified
if nexthop.Intf != nil {
var luid luid
if err := convertInterfaceIndexToLUID(uint32(nexthop.Intf.Index), &luid); err != nil {
return nil, fmt.Errorf("convert interface index to luid: %w", err)
}
route.InterfaceLuid = luid
route.InterfaceIndex = uint32(nexthop.Intf.Index)
}
routeCmd := uspfilter.GetSystem32Command("route")
if err := setDestinationPrefix(&route.DestinationPrefix, prefix); err != nil {
return nil, fmt.Errorf("set destination prefix: %w", err)
}
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if nexthop.IP.IsValid() {
if err := setNextHop(&route.NextHop, nexthop.IP); err != nil {
return nil, fmt.Errorf("set next hop: %w", err)
}
}
if err != nil {
return fmt.Errorf("remove route: %w", err)
return route, nil
}
// addRoute adds a route using Windows iphelper APIs
func addRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in addRoute: %v, stack trace: %s", r, debug.Stack())
}
}()
route, setupErr := setupRouteEntry(prefix, nexthop)
if setupErr != nil {
return fmt.Errorf("setup route entry: %w", setupErr)
}
route.Metric = 1
route.ValidLifetime = InfiniteLifetime
route.PreferredLifetime = InfiniteLifetime
return createIPForwardEntry2(route)
}
// deleteRoute deletes a route using Windows iphelper APIs
func deleteRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in deleteRoute: %v, stack trace: %s", r, debug.Stack())
}
}()
route, setupErr := setupRouteEntry(prefix, nexthop)
if setupErr != nil {
return fmt.Errorf("setup route entry: %w", setupErr)
}
if err := getIPForwardEntry2(route); err != nil {
return fmt.Errorf("get route entry: %w", err)
}
return deleteIPForwardEntry2(route)
}
// setDestinationPrefix sets the destination prefix in the route structure
func setDestinationPrefix(prefix *IP_ADDRESS_PREFIX, dest netip.Prefix) error {
addr := dest.Addr()
prefix.PrefixLength = uint8(dest.Bits())
if addr.Is4() {
prefix.Prefix.sin6_family = windows.AF_INET
ip4 := addr.As4()
binary.BigEndian.PutUint32(prefix.Prefix.data[:4],
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
return nil
}
if addr.Is6() {
prefix.Prefix.sin6_family = windows.AF_INET6
ip6 := addr.As16()
copy(prefix.Prefix.data[4:20], ip6[:])
if zone := addr.Zone(); zone != "" {
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
binary.BigEndian.PutUint32(prefix.Prefix.data[20:24], uint32(scopeID))
}
}
return nil
}
return fmt.Errorf("invalid address family")
}
// setNextHop sets the next hop address in the route structure
func setNextHop(nextHop *SOCKADDR_INET_NEXTHOP, addr netip.Addr) error {
if addr.Is4() {
nextHop.sin6_family = windows.AF_INET
ip4 := addr.As4()
binary.BigEndian.PutUint32(nextHop.data[:4],
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
return nil
}
if addr.Is6() {
nextHop.sin6_family = windows.AF_INET6
ip6 := addr.As16()
copy(nextHop.data[4:20], ip6[:])
// Handle zone if present
if zone := addr.Zone(); zone != "" {
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
binary.BigEndian.PutUint32(nextHop.data[20:24], uint32(scopeID))
}
}
return nil
}
return fmt.Errorf("invalid address family")
}
// Windows API wrappers
func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
}
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1)
}
return nil
}
func deleteIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("DeleteIpForwardEntry2: %w", e1)
}
return fmt.Errorf("DeleteIpForwardEntry2: code %d", r1)
}
return nil
}
func getIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("GetIpForwardEntry2: %w", e1)
}
return fmt.Errorf("GetIpForwardEntry2: code %d", r1)
}
return nil
}
// https://learn.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-initializeipforwardentry
func initializeIPForwardEntry(route *MIB_IPFORWARD_ROW2) {
// Does not return anything. Trying to handle the error might return an uninitialized value.
_, _, _ = syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route)))
}
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *luid) error {
r1, _, e1 := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(),
uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("ConvertInterfaceIndexToLuid: %w", e1)
}
return fmt.Errorf("ConvertInterfaceIndexToLuid: code %d", r1)
}
return nil
}
@ -319,7 +492,7 @@ func cancelMibChangeNotify2(handle windows.Handle) error {
}
// GetRoutesFromTable returns the current routing table from with prefixes only.
// It ccaches the result for 2 seconds to avoid blocking the caller.
// It caches the result for 2 seconds to avoid blocking the caller.
func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
@ -388,35 +561,6 @@ func GetRoutes() ([]Route, error) {
return routes, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@ -5,18 +5,23 @@ import (
"encoding/json"
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
var expectedExtInt = "Ethernet1"
var (
expectedExternalInt = "Ethernet1"
expectedVPNint = "wgtest0"
)
type RouteInfo struct {
NextHop string `json:"nexthop"`
@ -43,8 +48,6 @@ type testCase struct {
dialer dialer
}
var expectedVPNint = "wgtest0"
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
@ -52,14 +55,14 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "128.0.0.0/1",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedDestPrefix: "192.0.2.1/32",
expectedInterface: expectedExtInt,
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
},
@ -67,24 +70,15 @@ var testCases = []testCase{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedDestPrefix: "10.0.0.2/32",
expectedInterface: expectedExtInt,
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedDestPrefix: "172.16.0.2/32",
expectedInterface: expectedExtInt,
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
},
{
@ -93,7 +87,7 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "172.16.0.0/12",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
@ -103,22 +97,14 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "10.10.0.0/24",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
}
func TestRouting(t *testing.T) {
log.SetLevel(log.DebugLevel)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
@ -129,7 +115,7 @@ func TestRouting(t *testing.T) {
require.NoError(t, err, "Failed to fetch interface IP")
output := testRoute(t, tc.destination, tc.dialer)
if tc.expectedInterface == expectedExtInt {
if tc.expectedInterface == expectedExternalInt {
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
} else {
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
@ -242,19 +228,23 @@ func setupDummyInterfacesAndRoutes(t *testing.T) {
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
prefix, err := netip.ParsePrefix(dstCIDR)
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
t.Fatalf("Failed to parse destination CIDR %s: %v", dstCIDR, err)
}
nexthop := Nexthop{
Intf: &net.Interface{Index: 1},
}
if err = addRoute(prefix, nexthop); err != nil {
t.Fatalf("Failed to add dummy route: %v", err)
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
err := deleteRoute(prefix, nexthop)
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
t.Logf("Failed to remove dummy route: %v", err)
}
})
}

View File

@ -3,11 +3,11 @@ package server
import (
"context"
"fmt"
"net"
"net/netip"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
@ -19,81 +19,32 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
s.mutex.Lock()
defer s.mutex.Unlock()
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
tracer, engine, err := s.getPacketTracer()
if err != nil {
return nil, err
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
srcAddr, err := s.parseAddress(req.GetSourceIp(), engine)
if err != nil {
return nil, fmt.Errorf("invalid source IP address: %w", err)
}
tracer, ok := fwManager.(packetTracer)
if !ok {
return nil, fmt.Errorf("firewall manager does not support packet tracing")
dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine)
if err != nil {
return nil, fmt.Errorf("invalid destination IP address: %w", err)
}
srcIP := net.ParseIP(req.GetSourceIp())
if req.GetSourceIp() == "self" {
srcIP = engine.GetWgAddr()
protocol, err := s.parseProtocol(req.GetProtocol())
if err != nil {
return nil, err
}
srcAddr, ok := netip.AddrFromSlice(srcIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
direction, err := s.parseDirection(req.GetDirection())
if err != nil {
return nil, err
}
dstIP := net.ParseIP(req.GetDestinationIp())
if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr()
}
dstAddr, ok := netip.AddrFromSlice(dstIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address")
}
var tcpState *uspfilter.TCPState
if flags := req.GetTcpFlags(); flags != nil {
tcpState = &uspfilter.TCPState{
SYN: flags.GetSyn(),
ACK: flags.GetAck(),
FIN: flags.GetFin(),
RST: flags.GetRst(),
PSH: flags.GetPsh(),
URG: flags.GetUrg(),
}
}
var dir fw.RuleDirection
switch req.GetDirection() {
case "in":
dir = fw.RuleDirectionIN
case "out":
dir = fw.RuleDirectionOUT
default:
return nil, fmt.Errorf("invalid direction")
}
var protocol fw.Protocol
switch req.GetProtocol() {
case "tcp":
protocol = fw.ProtocolTCP
case "udp":
protocol = fw.ProtocolUDP
case "icmp":
protocol = fw.ProtocolICMP
default:
return nil, fmt.Errorf("invalid protocolcol")
}
tcpState := s.parseTCPFlags(req.GetTcpFlags())
builder := &uspfilter.PacketBuilder{
SrcIP: srcAddr,
@ -101,16 +52,96 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()),
Direction: dir,
Direction: direction,
TCPState: tcpState,
ICMPType: uint8(req.GetIcmpType()),
ICMPCode: uint8(req.GetIcmpCode()),
}
trace, err := tracer.TracePacketFromBuilder(builder)
if err != nil {
return nil, fmt.Errorf("trace packet: %w", err)
}
return s.buildTraceResponse(trace), nil
}
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
if s.connectClient == nil {
return nil, nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, nil, fmt.Errorf("firewall manager not initialized")
}
tracer, ok := fwManager.(packetTracer)
if !ok {
return nil, nil, fmt.Errorf("firewall manager does not support packet tracing")
}
return tracer, engine, nil
}
func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) {
if addr == "self" {
return engine.GetWgAddr(), nil
}
a, err := netip.ParseAddr(addr)
if err != nil {
return netip.Addr{}, err
}
return a.Unmap(), nil
}
func (s *Server) parseProtocol(protocol string) (fw.Protocol, error) {
switch protocol {
case "tcp":
return fw.ProtocolTCP, nil
case "udp":
return fw.ProtocolUDP, nil
case "icmp":
return fw.ProtocolICMP, nil
default:
return "", fmt.Errorf("invalid protocol")
}
}
func (s *Server) parseDirection(direction string) (fw.RuleDirection, error) {
switch direction {
case "in":
return fw.RuleDirectionIN, nil
case "out":
return fw.RuleDirectionOUT, nil
default:
return 0, fmt.Errorf("invalid direction")
}
}
func (s *Server) parseTCPFlags(flags *proto.TCPFlags) *uspfilter.TCPState {
if flags == nil {
return nil
}
return &uspfilter.TCPState{
SYN: flags.GetSyn(),
ACK: flags.GetAck(),
FIN: flags.GetFin(),
RST: flags.GetRst(),
PSH: flags.GetPsh(),
URG: flags.GetUrg(),
}
}
func (s *Server) buildTraceResponse(trace *uspfilter.PacketTrace) *proto.TracePacketResponse {
resp := &proto.TracePacketResponse{}
for _, result := range trace.Results {
@ -119,10 +150,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
Message: result.Message,
Allowed: result.Allowed,
}
if result.ForwarderAction != nil {
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
stage.ForwardingDetails = &details
}
resp.Stages = append(resp.Stages, stage)
}
@ -130,5 +163,5 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
}
return resp, nil
return resp
}

View File

@ -1,8 +1,10 @@
package net
import (
"fmt"
"math/big"
"net"
"net/netip"
"github.com/google/uuid"
)
@ -54,11 +56,13 @@ func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
// Calculate the last IP in the CIDR range
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
addr := network.Addr().AsSlice()
mask := net.CIDRMask(network.Bits(), len(addr)*8)
for i := 0; i < len(addr); i++ {
endIP = append(endIP, addr[i]|^mask[i])
}
// convert to big.Int
@ -70,5 +74,10 @@ func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return resultInt.Bytes()
ip, ok := netip.AddrFromSlice(resultInt.Bytes())
if !ok {
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
}
return ip.Unmap(), nil
}

94
util/net/net_test.go Normal file
View File

@ -0,0 +1,94 @@
package net
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
name string
network string
fromEnd int
expected string
expectErr bool
}{
{
name: "IPv4 /24 network - last IP (fromEnd=0)",
network: "192.168.1.0/24",
fromEnd: 0,
expected: "192.168.1.255",
},
{
name: "IPv4 /24 network - fromEnd=1",
network: "192.168.1.0/24",
fromEnd: 1,
expected: "192.168.1.254",
},
{
name: "IPv4 /24 network - fromEnd=5",
network: "192.168.1.0/24",
fromEnd: 5,
expected: "192.168.1.250",
},
{
name: "IPv4 /16 network - last IP",
network: "10.0.0.0/16",
fromEnd: 0,
expected: "10.0.255.255",
},
{
name: "IPv4 /16 network - fromEnd=256",
network: "10.0.0.0/16",
fromEnd: 256,
expected: "10.0.254.255",
},
{
name: "IPv4 /32 network - single host",
network: "192.168.1.100/32",
fromEnd: 0,
expected: "192.168.1.100",
},
{
name: "IPv6 /64 network - last IP",
network: "2001:db8::/64",
fromEnd: 0,
expected: "2001:db8::ffff:ffff:ffff:ffff",
},
{
name: "IPv6 /64 network - fromEnd=1",
network: "2001:db8::/64",
fromEnd: 1,
expected: "2001:db8::ffff:ffff:ffff:fffe",
},
{
name: "IPv6 /128 network - single host",
network: "2001:db8::1/128",
fromEnd: 0,
expected: "2001:db8::1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
network, err := netip.ParsePrefix(tt.network)
require.NoError(t, err, "Failed to parse network prefix")
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
if tt.expectErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
expectedIP, err := netip.ParseAddr(tt.expected)
require.NoError(t, err, "Failed to parse expected IP")
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
})
}
}