[client] Use platform-native routing APIs for freeBSD, macOS and Windows

This commit is contained in:
Viktor Liu 2025-06-04 16:28:58 +02:00 committed by GitHub
parent 87148c503f
commit ea4d13e96d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
53 changed files with 1552 additions and 1046 deletions

View File

@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,

View File

@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"time" "time"
@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@ -3,7 +3,6 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1").Unmap()
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: add.AsSlice(), Data: ip.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip net.IP ip tcpip.Address
netstack bool netstack bool
} }
@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: ones, PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: iface.Address().IP, ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
} }
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
} }
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if ipv4 := ip.To4(); ipv4 != nil { if !ip.Is4() {
high := uint16(ipv4[0]) return
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) }
ipv4 := ip.AsSlice()
if bitmap[high] == nil { high := uint16(ipv4[0])
bitmap[high] = &ipv4LowBitmap{} low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
}
index := low / 32 if bitmap[high] == nil {
bit := low % 32 bitmap[high] = &ipv4LowBitmap{}
bitmap[high].bitmap[index] |= 1 << bit }
ipStr := ipv4.String() index := low / 32
if _, exists := ipv4Set[ipStr]; !exists { bit := low % 32
ipv4Set[ipStr] = struct{}{} bitmap[high].bitmap[index] |= 1 << bit
*ipv4Addresses = append(*ipv4Addresses, ipStr)
} if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
} }
} }
@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
} }
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil { addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{}) ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []string var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match - addresses 32 apart", name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.33"), testIP: netip.MustParseAddr("192.168.1.33"),
expected: false, expected: false,
@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: netip.MustParseAddr("fe80::1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,

View File

@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@ -71,7 +71,6 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@ -1091,11 +1090,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not

View File

@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept) dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -19,12 +19,8 @@ import (
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := net.ParseIP("100.10.0.100") localIP := netip.MustParseAddr("100.10.0.100")
wgNet := &net.IPNet{ wgNet := netip.MustParsePrefix("100.10.0.0/16")
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network) wgNet := netip.MustParsePrefix(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: localIP, IP: wgNet.Addr(),
Network: wgNet, Network: wgNet,
} }
}, },
@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{

View File

@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@ -1,7 +1,6 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@ -24,9 +23,6 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil

View File

@ -5,7 +5,6 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@ -1,8 +1,6 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@ -15,8 +13,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address net.IP address netip.Addr
dnsAddress net.IP dnsAddress netip.Addr
mtu int mtu int
listenAddress string listenAddress string
@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{addr.Unmap()}, []netip.Addr{t.address},
[]netip.Addr{dnsAddr.Unmap()}, []netip.Addr{t.dnsAddress},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@ -2,28 +2,27 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net" "net/netip"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP net.IP IP netip.Addr
Network *net.IPNet Network netip.Prefix
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) prefix, err := netip.ParsePrefix(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: ip, IP: prefix.Addr().Unmap(),
Network: network, Network: prefix.Masked(),
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@ -1,7 +1,7 @@
package acl package acl
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -43,12 +43,11 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
require.NoError(t, err)
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
@ -162,12 +161,11 @@ func TestDefaultManagerStateless(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
require.NoError(t, err)
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
@ -372,12 +370,11 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
require.NoError(t, err)
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()

View File

@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"strings" "strings"
@ -12,13 +12,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData) ip, err := netip.ParseAddr(aRecord.RData)
if ip == nil || ip.To4() == nil { if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
if !ipNet.Contains(ip) { if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
} }
// generateReverseZoneName creates the reverse DNS zone name for a given network // generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) { func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask) networkIP := network.Masked().Addr()
maskOnes, _ := ipNet.Mask.Size()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte // round up to nearest byte
octetsToUse := (maskOnes + 7) / 8 octetsToUse := (network.Bits() + 7) / 8
octets := strings.Split(networkIP.String(), ".") octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) { if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
} }
reverseOctets := make([]string, octetsToUse) reverseOctets := make([]string, octetsToUse)
@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
} }
// collectPTRRecords gathers all PTR records for the given network from A records // collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones { for _, zone := range config.CustomZones {
@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
continue continue
} }
if ptrRecord, ok := createPTRRecord(record, ipNet); ok { if ptrRecord, ok := createPTRRecord(record, prefix); ok {
records = append(records, ptrRecord) records = append(records, ptrRecord)
} }
} }
@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
} }
// addReverseZone adds a reverse DNS zone to the configuration for the given network // addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(ipNet) zoneName, err := generateReverseZoneName(network)
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
return return
@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
return return
} }
records := collectPTRRecords(config, ipNet) records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{ reverseZone := nbdns.CustomZone{
Domain: zoneName, Domain: zoneName,

View File

@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
} }
func (w *mocWGIface) Address() wgaddr.Address { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{ return wgaddr.Address{
IP: ip, IP: netip.MustParseAddr("100.66.100.1"),
Network: network, Network: netip.MustParsePrefix("100.66.100.0/24"),
} }
} }
@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil { if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err) t.Errorf("set packet filter: %v", err)

View File

@ -24,11 +24,15 @@ type ServiceViaMemory struct {
} }
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(), runtimeIP: lastIP.String(),
runtimePort: defaultPort, runtimePort: defaultPort,
} }
return s return s
@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
} }
firstLayerDecoder := layers.LayerTypeIPv4 firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil { if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6 firstLayerDecoder = layers.LayerTypeIPv6
} }

View File

@ -1,33 +0,0 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@ -3,6 +3,7 @@ package dns
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@ -23,8 +24,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain string,

View File

@ -4,7 +4,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -19,8 +19,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,

View File

@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@ -18,16 +19,16 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP netip.Addr
lNet *net.IPNet lNet netip.Prefix
interfaceName string interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
interfaceName string, interfaceName string,
ip net.IP, ip netip.Addr,
net *net.IPNet, net netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} }
client.DialTimeout = timeout client.DialTimeout = timeout
upstreamIP := net.ParseIP(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil { if err != nil {
@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(interfaceName)
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: ip, IP: ip.AsSlice(),
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,

View File

@ -2,7 +2,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@ -1008,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update routes: %v", err)
} }
if e.acl != nil { if e.acl != nil {
@ -1104,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains), Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
@ -1138,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
return entries return entries
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
@ -1790,9 +1790,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
} }
// GetWgAddr returns the wireguard address // GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() net.IP { func (e *Engine) GetWgAddr() netip.Addr {
if e.wgInterface == nil { if e.wgInterface == nil {
return nil return netip.Addr{}
} }
return e.wgInterface.Address().IP return e.wgInterface.Address().IP
} }
@ -1861,12 +1861,7 @@ func (e *Engine) Address() (netip.Addr, error) {
return netip.Addr{}, errors.New("wireguard interface not initialized") return netip.Addr{}, errors.New("wireguard interface not initialized")
} }
addr := e.wgInterface.Address() return e.wgInterface.Address().IP, nil
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
}
return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {

View File

@ -371,11 +371,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {

View File

@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
// fallback if mark rules are not in place // fallback if mark rules are not in place
wgnet := c.iface.Address().Network wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice()) return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
} }
// mapRxPackets maps packet counts to RX based on flow direction // mapRxPackets maps packet counts to RX based on flow direction
@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
// fallback if marks are not set // fallback if marks are not set
wgaddr := c.iface.Address().IP wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch { switch {
case wgaddr.Equal(src): case wgaddr == srcIP:
return nftypes.Egress return nftypes.Egress
case wgaddr.Equal(dst): case wgaddr == dstIP:
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(src): case wgnetwork.Contains(srcIP):
// netbird network -> resource network // netbird network -> resource network
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(dst): case wgnetwork.Contains(dstIP):
// resource network -> netbird network // resource network -> netbird network
return nftypes.Egress return nftypes.Egress
} }

View File

@ -2,7 +2,7 @@ package logger
import ( import (
"context" "context"
"net" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -23,17 +23,16 @@ type Logger struct {
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgIfaceIPNet net.IPNet wgIfaceNet netip.Prefix
dnsCollection atomic.Bool dnsCollection atomic.Bool
exitNodeCollection atomic.Bool exitNodeCollection atomic.Bool
Store types.Store Store types.Store
} }
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
return &Logger{ return &Logger{
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet, wgIfaceNet: wgIfaceIPNet,
Store: store.NewMemoryStore(), Store: store.NewMemoryStore(),
} }
} }
@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool var isSrcExitNode bool
var isDestExitNode bool var isDestExitNode bool
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.SourceIP) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
} }
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.DestIP) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
} }

View File

@ -1,7 +1,7 @@
package logger_test package logger_test
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@ -12,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(nil, net.IPNet{}) logger := logger.New(nil, netip.Prefix{})
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net" "net/netip"
"runtime" "runtime"
"sync" "sync"
"time" "time"
@ -34,11 +34,11 @@ type Manager struct {
// NewManager creates a new netflow manager // NewManager creates a new netflow manager
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet var prefix netip.Prefix
if iface != nil { if iface != nil {
ipNet = *iface.Address().Network prefix = iface.Address().Network
} }
flowLogger := logger.New(statusRecorder, ipNet) flowLogger := logger.New(statusRecorder, prefix)
var ct nftypes.ConnTracker var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {

View File

@ -1,7 +1,7 @@
package netflow package netflow
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
func TestManager_Update(t *testing.T) { func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{ mockIFace := &mockIFaceMapper{
address: wgaddr.Address{ address: wgaddr.Address{
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.1/32"),
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
}, },
isUserspaceBind: true, isUserspaceBind: true,
} }
@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
func TestManager_Update_TokenPreservation(t *testing.T) { func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{ mockIFace := &mockIFaceMapper{
address: wgaddr.Address{ address: wgaddr.Address{
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.1/32"),
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
}, },
isUserspaceBind: true, isUserspaceBind: true,
} }

View File

@ -264,7 +264,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
continue continue
} }
prefix := netip.PrefixFrom(ip, ip.BitLen()) prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
newPrefixes = append(newPrefixes, prefix) newPrefixes = append(newPrefixes, prefix)
} }

View File

@ -333,11 +333,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
var merr *multierror.Error
if !m.disableClientRoutes { if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil { if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
log.Errorf("Failed to update system routes: %v", err) merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
} }
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
@ -346,14 +347,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.clientRoutes = newClientRoutesIDMap m.clientRoutes = newClientRoutesIDMap
if m.serverRouter == nil { if m.serverRouter == nil {
return nil return nberrors.FormatErrorOrNil(merr)
} }
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
return fmt.Errorf("update routes: %w", err) merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
// SetRouteChangeListener set RouteListener for route change Notifier // SetRouteChangeListener set RouteListener for route change Notifier

View File

@ -44,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -71,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.252.250/30"), Network: netip.MustParsePrefix("100.64.252.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -99,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"), Network: netip.MustParsePrefix("100.64.30.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -127,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"), Network: netip.MustParsePrefix("100.64.30.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -211,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -233,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -250,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -272,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -282,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "b", ID: "b",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey2, Peer: remotePeerKey2,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -299,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -327,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -356,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "l1", ID: "l1",
NetID: "routeA", NetID: "routeA",
Peer: localPeerKey, Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -376,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
ID: "r1", ID: "r1",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"), Network: netip.MustParsePrefix("100.64.251.248/30"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
@ -440,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
} }
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
const ( const (
@ -22,8 +22,13 @@ const (
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
) )
type iface interface {
Address() wgaddr.Address
Name() string
}
// Setup configures sysctl settings for RP filtering and source validation. // Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface iface.WGIface) (map[string]int, error) { func Setup(wgIface iface) (map[string]int, error) {
keys := map[string]int{} keys := map[string]int{}
var result *multierror.Error var result *multierror.Error

View File

@ -6,9 +6,10 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
) )
type Nexthop struct { type Nexthop struct {
@ -30,11 +31,16 @@ func (n Nexthop) String() string {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name) return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
} }
type wgIface interface {
Address() wgaddr.Address
Name() string
}
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct { type SysOps struct {
refCounter *ExclusionCounter refCounter *ExclusionCounter
wgInterface iface.WGIface wgInterface wgIface
// prefixes is tracking all the current added prefixes im memory // prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update) // (this is used in iOS as all route updates require a full table update)
//nolint //nolint
@ -45,9 +51,27 @@ type SysOps struct {
notifier *notifier.Notifier notifier *notifier.Notifier
} }
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,
} }
} }
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()
switch {
case
!addr.IsValid(),
addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsMulticast(),
addr.IsUnspecified() && prefix.Bits() != 0,
r.wgInterface.Address().Network.Contains(addr):
return vars.ErrRouteNotAllowed
}
return nil
}

View File

@ -8,6 +8,8 @@ import (
"net/netip" "net/netip"
"os/exec" "os/exec"
"regexp" "regexp"
"runtime"
"strings"
"sync" "sync"
"testing" "testing"
@ -33,7 +35,12 @@ func init() {
func TestConcurrentRoutes(t *testing.T) { func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil) r := NewSysOps(nil, nil)
@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err) t.Errorf("Failed to add route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err) t.Errorf("Failed to remove route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper() t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() if runtime.GOOS == "darwin" {
require.NoError(t, err, "Failed to create loopback alias") err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() { t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove loopback alias") assert.NoError(t, err, "Failed to remove route from table")
}) })
return "lo0" return intf
} }
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper() t.Helper()
var originalNexthop net.IP var originalNexthop net.IP
@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
return net.ParseIP(matches[1]), nil return net.ParseIP(matches[1]), nil
} }
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) { func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper() t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
} }

View File

@ -17,7 +17,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil return nil
} }
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values. // If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr() if err := r.validateRoute(prefix); err != nil {
switch { return Nexthop{}, err
case addr.IsLoopback(), }
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
addr := prefix.Addr()
if addr.IsUnspecified() {
return Nexthop{}, vars.ErrRouteNotAllowed return Nexthop{}, vars.ErrRouteNotAllowed
} }
@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
Intf: nexthop.Intf, Intf: nexthop.Intf,
} }
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) vpnAddr := vpnIntf.Address().IP
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
return nil return nil
} }
return r.addNonExistingRoute(prefix, intf) return r.addToRouteTable(prefix, nextHop)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
} }
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, // genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil { if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil { if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound return Nexthop{Intf: intf}, nil
} }
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc) log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
return addr.Unmap(), nil return addr.Unmap(), nil
} }
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting() localRoutes, err := hasSeparateRouting()

View File

@ -3,23 +3,25 @@
package systemops package systemops
import ( import (
"bytes"
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os" "os/exec"
"runtime" "runtime"
"strconv"
"strings" "strings"
"syscall"
"testing" "testing"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
) )
type dialer interface { type dialer interface {
@ -27,105 +29,370 @@ type dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
func TestAddRemoveRoutes(t *testing.T) { func TestAddVPNRoute(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
prefix netip.Prefix prefix netip.Prefix
shouldRouteToWireguard bool expectError bool
shouldBeRemoved bool
}{ }{
{ {
name: "Should Add And Remove Route 100.66.120.0/24", name: "IPv4 - Private network route",
prefix: netip.MustParsePrefix("100.66.120.0/24"), prefix: netip.MustParsePrefix("10.10.100.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
}, },
{ {
name: "Should Not Add Or Remove Route 127.0.0.1/32", name: "IPv4 Single host",
prefix: netip.MustParsePrefix("127.0.0.1/32"), prefix: netip.MustParsePrefix("10.111.111.111/32"),
shouldRouteToWireguard: false, },
shouldBeRemoved: false, {
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Default route",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
},
{
name: "IPv6 Default route",
prefix: netip.MustParsePrefix("::/0"),
},
// IPv4 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local multicast",
prefix: netip.MustParsePrefix("224.0.0.251/32"),
expectError: true,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
},
{
name: "IPv4 Unspecified with prefix",
prefix: netip.MustParsePrefix("0.0.0.0/32"),
expectError: true,
},
// IPv6 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local multicast",
prefix: netip.MustParsePrefix("ff02::1/128"),
expectError: true,
},
{
name: "IPv6 Interface-local multicast",
prefix: netip.MustParsePrefix("ff01::1/128"),
expectError: true,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
},
{
name: "IPv6 Unspecified with prefix",
prefix: netip.MustParsePrefix("::/128"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network subnet",
prefix: netip.MustParsePrefix("100.65.75.0/32"),
expectError: true,
}, },
} }
for n, testCase := range testCases { for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey() wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
_, _, err = r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
}) })
index, err := net.InterfaceByName(wgInterface.Name()) intf, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err)
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// add the route
err = r.AddVPNRoute(testCase.prefix, intf) err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.expectError {
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
if testCase.shouldRouteToWireguard { // validate it's pointing to the WireGuard interface
assertWGOutInterface(t, testCase.prefix, wgInterface, false) require.NoError(t, err)
nextHop := getNextHop(t, testCase.prefix.Addr())
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
// remove route again
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err)
// validate it's gone
nextHop, err = GetNextHop(testCase.prefix.Addr())
require.True(t,
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
"err: %v, next hop: %v", err, nextHop)
})
}
}
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
t.Helper()
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
nextHop, err := GetNextHop(addr)
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
// present in the route table.
t.Skip("Skipping windows test")
}
require.NoError(t, err)
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
return nextHop
}
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
if addr.IsUnspecified() {
// On macOS, querying 0.0.0.0 returns the wrong interface
if addr.Is4() {
addr = netip.MustParseAddr("1.2.3.4")
} else {
addr = netip.MustParseAddr("2001:db8::1")
}
}
cmd := exec.Command("route", "-n", "get", addr.String())
if addr.Is6() {
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
}
output, err := cmd.CombinedOutput()
t.Logf("route output: %s", output)
require.NoError(t, err, "%s failed")
lines := strings.Split(string(output), "\n")
var intf string
var gateway string
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "interface:") {
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
} else if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
}
require.NotEmpty(t, intf, "interface should be found in route output")
iface, err := net.InterfaceByName(intf)
require.NoError(t, err, "interface %s should exist", intf)
nexthop := Nexthop{Intf: iface}
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
addr, err := netip.ParseAddr(gateway)
if err == nil {
nexthop.IP = addr
}
}
return nexthop
}
func TestAddRouteToNonVPNIntf(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
expectError bool
errorType error
}{
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("8.8.8.8/32"),
},
{
name: "IPv6 External network route",
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2001:db8::1/128"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
},
// Addresses that should be rejected
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Unspecified",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Unspecified",
prefix: netip.MustParsePrefix("::/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Fatalf("Failed to get IPv6 default route: %v", err)
}
var initialNextHop Nexthop
if testCase.prefix.Addr().Is6() {
initialNextHop = initialNextHopV6
} else { } else {
assertWGOutInterface(t, testCase.prefix, wgInterface, true) initialNextHop = initialNextHopV4
} }
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr()) nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
require.NoError(t, err, "GetNextHop should not return err")
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) if testCase.expectError {
require.NoError(t, err) require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
} }
require.NoError(t, err)
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
// Verify the route was added and points to non-VPN interface
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err)
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
err = r.removeFromRouteTable(testCase.prefix, nexthop)
assert.NoError(t, err)
}) })
} }
} }
func TestGetNextHop(t *testing.T) { func TestGetNextHop(t *testing.T) {
if runtime.GOOS == "freebsd" { defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil { if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err) t.Fatal("shouldn't return error when fetching the gateway: ", err)
} }
if !nexthop.IP.IsValid() { if !defaultNh.IP.IsValid() {
t.Fatal("should return a gateway") t.Fatal("should return a gateway")
} }
addresses, err := net.InterfaceAddrs() addresses, err := net.InterfaceAddrs()
@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
t.Fatal("shouldn't return error when fetching interface addresses: ", err) t.Fatal("shouldn't return error when fetching interface addresses: ", err)
} }
var testingIP string
var testingPrefix netip.Prefix var testingPrefix netip.Prefix
for _, address := range addresses { for _, address := range addresses {
if address.Network() != "ip+net" { if address.Network() != "ip+net" {
@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
} }
prefix := netip.MustParsePrefix(address.String()) prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked() testingPrefix = prefix.Masked()
break break
} }
} }
localIP, err := GetNextHop(testingPrefix.Addr()) nh, err := GetNextHop(testingPrefix.Addr())
if err != nil { if err != nil {
t.Fatal("shouldn't return error: ", err) t.Fatal("shouldn't return error: ", err)
} }
if !localIP.IP.IsValid() { if nh.Intf == nil {
t.Fatal("should return a gateway for local network") t.Fatal("should return a gateway for local network")
} }
if localIP.IP.String() == nexthop.IP.String() { if nh.IP.String() == defaultNh.IP.String() {
t.Fatal("local IP should not match with gateway IP") t.Fatal("next hop IP should not match with default gateway IP")
} }
if localIP.IP.String() != testingIP { if nh.Intf.Name != defaultNh.Intf.Name {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String()) t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPort: 33100,
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
switch {
case p.Addr().Is6():
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
} }
} }
@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) { func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper() t.Helper()
err := r.AddVPNRoute(prefix, intf) if err := r.AddVPNRoute(prefix, intf); err != nil {
require.NoError(t, err, "addVPNRoute should not return err") if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("addVPNRoute should not return err: %v", err)
}
t.Logf("addVPNRoute %v returned: %v", prefix, err)
}
t.Cleanup(func() { t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf) if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
assert.NoError(t, err, "removeVPNRoute should not return err") t.Fatalf("removeVPNRoute should not return err: %v", err)
}
}) })
} }
@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
// 10.10.0.0/24 more specific route exists in vpn table // 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table // unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
} }
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) { func TestIsVpnRoute(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
} }
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
ones, _ := route.Dst.Mask.Size() ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones) prefix := netip.PrefixFrom(addr.Unmap(), ones)
if prefix.IsValid() { if prefix.IsValid() {
prefixList = append(prefixList, prefix) prefixList = append(prefixList, prefix)
} }
@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err) return fmt.Errorf("add gateway and device: %w", err)
} }
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add route: %w", err) return fmt.Errorf("netlink add route: %w", err)
} }
@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet, Dst: ipNet,
} }
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add unreachable route: %w", err) return fmt.Errorf("netlink add unreachable route: %w", err)
} }

View File

@ -19,7 +19,6 @@ import (
) )
var expectedVPNint = "wgtest0" var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0" var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0" var expectedInternalInt = "dummyint0"
@ -31,12 +30,6 @@ func init() {
dialer: &net.Dialer{}, dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
}, },
{
name: "To more specific route (local) without custom dialer via physical interface",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...) }...)
} }

View File

@ -11,10 +11,16 @@ import (
) )
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }

View File

@ -0,0 +1,268 @@
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type mockWGIface struct {
address wgaddr.Address
name string
}
func (m *mockWGIface) Address() wgaddr.Address {
return m.address
}
func (m *mockWGIface) Name() string {
return m.name
}
func TestSysOps_validateRoute(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
}{
// Valid routes
{
name: "valid IPv4 route",
prefix: "192.168.1.0/24",
expectError: false,
},
{
name: "valid IPv6 route",
prefix: "2001:db8::/32",
expectError: false,
},
{
name: "valid single IPv4 host",
prefix: "8.8.8.8/32",
expectError: false,
},
{
name: "valid single IPv6 host",
prefix: "2001:4860:4860::8888/128",
expectError: false,
},
// Invalid routes - loopback
{
name: "IPv4 loopback",
prefix: "127.0.0.1/32",
expectError: true,
},
{
name: "IPv6 loopback",
prefix: "::1/128",
expectError: true,
},
// Invalid routes - link-local unicast
{
name: "IPv4 link-local unicast",
prefix: "169.254.1.1/32",
expectError: true,
},
{
name: "IPv6 link-local unicast",
prefix: "fe80::1/128",
expectError: true,
},
// Invalid routes - multicast
{
name: "IPv4 multicast",
prefix: "224.0.0.1/32",
expectError: true,
},
{
name: "IPv6 multicast",
prefix: "ff02::1/128",
expectError: true,
},
// Invalid routes - link-local multicast
{
name: "IPv4 link-local multicast",
prefix: "224.0.0.0/24",
expectError: true,
},
{
name: "IPv6 link-local multicast",
prefix: "ff02::/16",
expectError: true,
},
// Invalid routes - interface-local multicast (IPv6 only)
{
name: "IPv6 interface-local multicast",
prefix: "ff01::1/128",
expectError: true,
},
// Invalid routes - overlaps with WG interface network
{
name: "overlaps with WG network - exact match",
prefix: "10.0.0.0/24",
expectError: true,
},
{
name: "overlaps with WG network - subset",
prefix: "10.0.0.1/32",
expectError: true,
},
{
name: "overlaps with WG network - host in range",
prefix: "10.0.0.100/32",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
}
})
}
}
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
description string
}{
{
name: "identical subnet",
prefix: "192.168.100.0/24",
expectError: true,
description: "exact same network as WG interface",
},
{
name: "broader subnet containing WG network",
prefix: "192.168.0.0/16",
expectError: false,
description: "broader network that contains WG network should be allowed",
},
{
name: "host within WG network",
prefix: "192.168.100.50/32",
expectError: true,
description: "specific host within WG network",
},
{
name: "subnet within WG network",
prefix: "192.168.100.128/25",
expectError: true,
description: "smaller subnet within WG network",
},
{
name: "adjacent subnet - same /23",
prefix: "192.168.101.0/24",
expectError: false,
description: "adjacent subnet, no overlap",
},
{
name: "adjacent subnet - different /16",
prefix: "192.167.100.0/24",
expectError: false,
description: "different network, no overlap",
},
{
name: "WG network broadcast address",
prefix: "192.168.100.255/32",
expectError: true,
description: "broadcast address of WG network",
},
{
name: "WG network first usable",
prefix: "192.168.100.1/32",
expectError: true,
description: "first usable address in WG network",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
}
})
}
}
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wt0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
var invalidPrefix netip.Prefix
err := sysOps.validateRoute(invalidPrefix)
require.Error(t, err, "validateRoute() expected error for invalid prefix")
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
}

View File

@ -3,15 +3,19 @@
package systemops package systemops
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os/exec" "strconv"
"strings" "syscall"
"time" "time"
"unsafe"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop) return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
} }
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop) return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
} }
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet" if !prefix.IsValid() {
if prefix.Addr().Is6() { return fmt.Errorf("invalid prefix: %s", prefix)
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
} }
expBackOff := backoff.NewExponentialBackOff() expBackOff := backoff.NewExponentialBackOff()
@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
expBackOff.MaxInterval = 500 * time.Millisecond expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff) if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
if err != nil { a := "add"
return fmt.Errorf("route cmd retry failed: %w", err) if action == unix.RTM_DELETE {
a = "remove"
}
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
} }
return nil return nil
} }
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
if err != nil {
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
}
if prefix.IsSingleIP() {
msg.Flags |= unix.RTF_HOST
} else {
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
if err != nil {
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
}
}
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
}
} else if nexthop.Intf != nil {
msg.Index = nexthop.Intf.Index
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {
return &route.Inet4Addr{IP: addr.As4()}, nil
}
if addr.Zone() == "" {
return &route.Inet6Addr{IP: addr.As16()}, nil
}
var zone int
// zone can be either a numeric zone ID or an interface name.
if z, err := strconv.Atoi(addr.Zone()); err == nil {
zone = z
} else {
iface, err := net.InterfaceByName(addr.Zone())
if err != nil {
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
}
zone = iface.Index
}
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
}
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
m := net.CIDRMask(bits, 32)
var maskBytes [4]byte
copy(maskBytes[:], m)
return &route.Inet4Addr{IP: maskBytes}, nil
}
if prefix.Addr().Is6() {
m := net.CIDRMask(bits, 128)
var maskBytes [16]byte
copy(maskBytes[:], m)
return &route.Inet6Addr{IP: maskBytes}, nil
}
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}

View File

@ -1,5 +1,3 @@
//go:build windows
package systemops package systemops
import ( import (
@ -9,9 +7,8 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"os/exec" "runtime/debug"
"strconv" "strconv"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -21,11 +18,12 @@ import (
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const InfiniteLifetime = 0xffffffff
type RouteUpdateType int type RouteUpdateType int
// RouteUpdate represents a change in the routing table. // RouteUpdate represents a change in the routing table.
@ -58,9 +56,13 @@ type MSFT_NetRoute struct {
AddressFamily uint16 AddressFamily uint16
} }
// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2 // luid represents a locally unique identifier for network interfaces
type luid uint64
// MIB_IPFORWARD_ROW2 represents a route entry in the routing table.
// It is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
type MIB_IPFORWARD_ROW2 struct { type MIB_IPFORWARD_ROW2 struct {
InterfaceLuid uint64 InterfaceLuid luid
InterfaceIndex uint32 InterfaceIndex uint32
DestinationPrefix IP_ADDRESS_PREFIX DestinationPrefix IP_ADDRESS_PREFIX
NextHop SOCKADDR_INET_NEXTHOP NextHop SOCKADDR_INET_NEXTHOP
@ -108,9 +110,14 @@ type SOCKADDR_INET_NEXTHOP struct {
type MIB_NOTIFICATION_TYPE int32 type MIB_NOTIFICATION_TYPE int32
var ( var (
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll") modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
prefixList []netip.Prefix prefixList []netip.Prefix
lastUpdate time.Time lastUpdate time.Time
@ -139,6 +146,8 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
} }
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
log.Debugf("Adding route to %s via %s", prefix, nexthop)
// if we don't have an interface but a zone, extract the interface index from the zone
if nexthop.IP.Zone() != "" && nexthop.Intf == nil { if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone()) zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil { if err != nil {
@ -147,23 +156,187 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
nexthop.Intf = &net.Interface{Index: zone} nexthop.Intf = &net.Interface{Index: zone}
} }
return addRouteCmd(prefix, nexthop) return addRoute(prefix, nexthop)
} }
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()} log.Debugf("Removing route to %s via %s", prefix, nexthop)
if nexthop.IP.IsValid() { return deleteRoute(prefix, nexthop)
ip := nexthop.IP.WithZone("") }
args = append(args, ip.Unmap().String())
// setupRouteEntry prepares a route entry with common configuration
func setupRouteEntry(prefix netip.Prefix, nexthop Nexthop) (*MIB_IPFORWARD_ROW2, error) {
route := &MIB_IPFORWARD_ROW2{}
initializeIPForwardEntry(route)
// Convert interface index to luid if interface is specified
if nexthop.Intf != nil {
var luid luid
if err := convertInterfaceIndexToLUID(uint32(nexthop.Intf.Index), &luid); err != nil {
return nil, fmt.Errorf("convert interface index to luid: %w", err)
}
route.InterfaceLuid = luid
route.InterfaceIndex = uint32(nexthop.Intf.Index)
} }
routeCmd := uspfilter.GetSystem32Command("route") if err := setDestinationPrefix(&route.DestinationPrefix, prefix); err != nil {
return nil, fmt.Errorf("set destination prefix: %w", err)
}
out, err := exec.Command(routeCmd, args...).CombinedOutput() if nexthop.IP.IsValid() {
log.Tracef("route %s: %s", strings.Join(args, " "), out) if err := setNextHop(&route.NextHop, nexthop.IP); err != nil {
return nil, fmt.Errorf("set next hop: %w", err)
}
}
if err != nil { return route, nil
return fmt.Errorf("remove route: %w", err) }
// addRoute adds a route using Windows iphelper APIs
func addRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in addRoute: %v, stack trace: %s", r, debug.Stack())
}
}()
route, setupErr := setupRouteEntry(prefix, nexthop)
if setupErr != nil {
return fmt.Errorf("setup route entry: %w", setupErr)
}
route.Metric = 1
route.ValidLifetime = InfiniteLifetime
route.PreferredLifetime = InfiniteLifetime
return createIPForwardEntry2(route)
}
// deleteRoute deletes a route using Windows iphelper APIs
func deleteRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in deleteRoute: %v, stack trace: %s", r, debug.Stack())
}
}()
route, setupErr := setupRouteEntry(prefix, nexthop)
if setupErr != nil {
return fmt.Errorf("setup route entry: %w", setupErr)
}
if err := getIPForwardEntry2(route); err != nil {
return fmt.Errorf("get route entry: %w", err)
}
return deleteIPForwardEntry2(route)
}
// setDestinationPrefix sets the destination prefix in the route structure
func setDestinationPrefix(prefix *IP_ADDRESS_PREFIX, dest netip.Prefix) error {
addr := dest.Addr()
prefix.PrefixLength = uint8(dest.Bits())
if addr.Is4() {
prefix.Prefix.sin6_family = windows.AF_INET
ip4 := addr.As4()
binary.BigEndian.PutUint32(prefix.Prefix.data[:4],
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
return nil
}
if addr.Is6() {
prefix.Prefix.sin6_family = windows.AF_INET6
ip6 := addr.As16()
copy(prefix.Prefix.data[4:20], ip6[:])
if zone := addr.Zone(); zone != "" {
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
binary.BigEndian.PutUint32(prefix.Prefix.data[20:24], uint32(scopeID))
}
}
return nil
}
return fmt.Errorf("invalid address family")
}
// setNextHop sets the next hop address in the route structure
func setNextHop(nextHop *SOCKADDR_INET_NEXTHOP, addr netip.Addr) error {
if addr.Is4() {
nextHop.sin6_family = windows.AF_INET
ip4 := addr.As4()
binary.BigEndian.PutUint32(nextHop.data[:4],
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
return nil
}
if addr.Is6() {
nextHop.sin6_family = windows.AF_INET6
ip6 := addr.As16()
copy(nextHop.data[4:20], ip6[:])
// Handle zone if present
if zone := addr.Zone(); zone != "" {
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
binary.BigEndian.PutUint32(nextHop.data[20:24], uint32(scopeID))
}
}
return nil
}
return fmt.Errorf("invalid address family")
}
// Windows API wrappers
func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
}
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1)
}
return nil
}
func deleteIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("DeleteIpForwardEntry2: %w", e1)
}
return fmt.Errorf("DeleteIpForwardEntry2: code %d", r1)
}
return nil
}
func getIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("GetIpForwardEntry2: %w", e1)
}
return fmt.Errorf("GetIpForwardEntry2: code %d", r1)
}
return nil
}
// https://learn.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-initializeipforwardentry
func initializeIPForwardEntry(route *MIB_IPFORWARD_ROW2) {
// Does not return anything. Trying to handle the error might return an uninitialized value.
_, _, _ = syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route)))
}
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *luid) error {
r1, _, e1 := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(),
uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("ConvertInterfaceIndexToLuid: %w", e1)
}
return fmt.Errorf("ConvertInterfaceIndexToLuid: code %d", r1)
} }
return nil return nil
} }
@ -319,7 +492,7 @@ func cancelMibChangeNotify2(handle windows.Handle) error {
} }
// GetRoutesFromTable returns the current routing table from with prefixes only. // GetRoutesFromTable returns the current routing table from with prefixes only.
// It ccaches the result for 2 seconds to avoid blocking the caller. // It caches the result for 2 seconds to avoid blocking the caller.
func GetRoutesFromTable() ([]netip.Prefix, error) { func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock() mux.Lock()
defer mux.Unlock() defer mux.Unlock()
@ -388,35 +561,6 @@ func GetRoutes() ([]Route, error) {
return routes, nil return routes, nil
} }
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
return nil
}
func isCacheDisabled() bool { func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
} }

View File

@ -5,18 +5,23 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/netip"
"os/exec" "os/exec"
"strings" "strings"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
var expectedExtInt = "Ethernet1" var (
expectedExternalInt = "Ethernet1"
expectedVPNint = "wgtest0"
)
type RouteInfo struct { type RouteInfo struct {
NextHop string `json:"nexthop"` NextHop string `json:"nexthop"`
@ -43,8 +48,6 @@ type testCase struct {
dialer dialer dialer dialer
} }
var expectedVPNint = "wgtest0"
var testCases = []testCase{ var testCases = []testCase{
{ {
name: "To external host without custom dialer via vpn", name: "To external host without custom dialer via vpn",
@ -52,14 +55,14 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1", expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "128.0.0.0/1", expectedDestPrefix: "128.0.0.0/1",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0", expectedInterface: expectedVPNint,
dialer: &net.Dialer{}, dialer: &net.Dialer{},
}, },
{ {
name: "To external host with custom dialer via physical interface", name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53", destination: "192.0.2.1:53",
expectedDestPrefix: "192.0.2.1/32", expectedDestPrefix: "192.0.2.1/32",
expectedInterface: expectedExtInt, expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(), dialer: nbnet.NewDialer(),
}, },
@ -67,24 +70,15 @@ var testCases = []testCase{
name: "To duplicate internal route with custom dialer via physical interface", name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53", destination: "10.0.0.2:53",
expectedDestPrefix: "10.0.0.2/32", expectedDestPrefix: "10.0.0.2/32",
expectedInterface: expectedExtInt, expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(), dialer: nbnet.NewDialer(),
}, },
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
{ {
name: "To unique vpn route with custom dialer via physical interface", name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53", destination: "172.16.0.2:53",
expectedDestPrefix: "172.16.0.2/32", expectedDestPrefix: "172.16.0.2/32",
expectedInterface: expectedExtInt, expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(), dialer: nbnet.NewDialer(),
}, },
{ {
@ -93,7 +87,7 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1", expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "172.16.0.0/12", expectedDestPrefix: "172.16.0.0/12",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0", expectedInterface: expectedVPNint,
dialer: &net.Dialer{}, dialer: &net.Dialer{},
}, },
@ -103,22 +97,14 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1", expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "10.10.0.0/24", expectedDestPrefix: "10.10.0.0/24",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0", expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{}, dialer: &net.Dialer{},
}, },
} }
func TestRouting(t *testing.T) { func TestRouting(t *testing.T) {
log.SetLevel(log.DebugLevel)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t) setupTestEnv(t)
@ -129,7 +115,7 @@ func TestRouting(t *testing.T) {
require.NoError(t, err, "Failed to fetch interface IP") require.NoError(t, err, "Failed to fetch interface IP")
output := testRoute(t, tc.destination, tc.dialer) output := testRoute(t, tc.destination, tc.dialer)
if tc.expectedInterface == expectedExtInt { if tc.expectedInterface == expectedExternalInt {
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
} else { } else {
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
@ -242,19 +228,23 @@ func setupDummyInterfacesAndRoutes(t *testing.T) {
func addDummyRoute(t *testing.T, dstCIDR string) { func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper() t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR) prefix, err := netip.ParsePrefix(dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil { if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output) t.Fatalf("Failed to parse destination CIDR %s: %v", dstCIDR, err)
t.FailNow() }
nexthop := Nexthop{
Intf: &net.Interface{Index: 1},
}
if err = addRoute(prefix, nexthop); err != nil {
t.Fatalf("Failed to add dummy route: %v", err)
} }
t.Cleanup(func() { t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR) err := deleteRoute(prefix, nexthop)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil { if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output) t.Logf("Failed to remove dummy route: %v", err)
} }
}) })
} }

View File

@ -3,11 +3,11 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@ -19,81 +19,32 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.connectClient == nil { tracer, engine, err := s.getPacketTracer()
return nil, fmt.Errorf("connect client not initialized") if err != nil {
} return nil, err
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
} }
fwManager := engine.GetFirewallManager() srcAddr, err := s.parseAddress(req.GetSourceIp(), engine)
if fwManager == nil { if err != nil {
return nil, fmt.Errorf("firewall manager not initialized") return nil, fmt.Errorf("invalid source IP address: %w", err)
} }
tracer, ok := fwManager.(packetTracer) dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine)
if !ok { if err != nil {
return nil, fmt.Errorf("firewall manager does not support packet tracing") return nil, fmt.Errorf("invalid destination IP address: %w", err)
} }
srcIP := net.ParseIP(req.GetSourceIp()) protocol, err := s.parseProtocol(req.GetProtocol())
if req.GetSourceIp() == "self" { if err != nil {
srcIP = engine.GetWgAddr() return nil, err
} }
srcAddr, ok := netip.AddrFromSlice(srcIP) direction, err := s.parseDirection(req.GetDirection())
if !ok { if err != nil {
return nil, fmt.Errorf("invalid source IP address") return nil, err
} }
dstIP := net.ParseIP(req.GetDestinationIp()) tcpState := s.parseTCPFlags(req.GetTcpFlags())
if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr()
}
dstAddr, ok := netip.AddrFromSlice(dstIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address")
}
var tcpState *uspfilter.TCPState
if flags := req.GetTcpFlags(); flags != nil {
tcpState = &uspfilter.TCPState{
SYN: flags.GetSyn(),
ACK: flags.GetAck(),
FIN: flags.GetFin(),
RST: flags.GetRst(),
PSH: flags.GetPsh(),
URG: flags.GetUrg(),
}
}
var dir fw.RuleDirection
switch req.GetDirection() {
case "in":
dir = fw.RuleDirectionIN
case "out":
dir = fw.RuleDirectionOUT
default:
return nil, fmt.Errorf("invalid direction")
}
var protocol fw.Protocol
switch req.GetProtocol() {
case "tcp":
protocol = fw.ProtocolTCP
case "udp":
protocol = fw.ProtocolUDP
case "icmp":
protocol = fw.ProtocolICMP
default:
return nil, fmt.Errorf("invalid protocolcol")
}
builder := &uspfilter.PacketBuilder{ builder := &uspfilter.PacketBuilder{
SrcIP: srcAddr, SrcIP: srcAddr,
@ -101,16 +52,96 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
Protocol: protocol, Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()), SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()), DstPort: uint16(req.GetDestinationPort()),
Direction: dir, Direction: direction,
TCPState: tcpState, TCPState: tcpState,
ICMPType: uint8(req.GetIcmpType()), ICMPType: uint8(req.GetIcmpType()),
ICMPCode: uint8(req.GetIcmpCode()), ICMPCode: uint8(req.GetIcmpCode()),
} }
trace, err := tracer.TracePacketFromBuilder(builder) trace, err := tracer.TracePacketFromBuilder(builder)
if err != nil { if err != nil {
return nil, fmt.Errorf("trace packet: %w", err) return nil, fmt.Errorf("trace packet: %w", err)
} }
return s.buildTraceResponse(trace), nil
}
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
if s.connectClient == nil {
return nil, nil, fmt.Errorf("connect client not initialized")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, nil, fmt.Errorf("firewall manager not initialized")
}
tracer, ok := fwManager.(packetTracer)
if !ok {
return nil, nil, fmt.Errorf("firewall manager does not support packet tracing")
}
return tracer, engine, nil
}
func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) {
if addr == "self" {
return engine.GetWgAddr(), nil
}
a, err := netip.ParseAddr(addr)
if err != nil {
return netip.Addr{}, err
}
return a.Unmap(), nil
}
func (s *Server) parseProtocol(protocol string) (fw.Protocol, error) {
switch protocol {
case "tcp":
return fw.ProtocolTCP, nil
case "udp":
return fw.ProtocolUDP, nil
case "icmp":
return fw.ProtocolICMP, nil
default:
return "", fmt.Errorf("invalid protocol")
}
}
func (s *Server) parseDirection(direction string) (fw.RuleDirection, error) {
switch direction {
case "in":
return fw.RuleDirectionIN, nil
case "out":
return fw.RuleDirectionOUT, nil
default:
return 0, fmt.Errorf("invalid direction")
}
}
func (s *Server) parseTCPFlags(flags *proto.TCPFlags) *uspfilter.TCPState {
if flags == nil {
return nil
}
return &uspfilter.TCPState{
SYN: flags.GetSyn(),
ACK: flags.GetAck(),
FIN: flags.GetFin(),
RST: flags.GetRst(),
PSH: flags.GetPsh(),
URG: flags.GetUrg(),
}
}
func (s *Server) buildTraceResponse(trace *uspfilter.PacketTrace) *proto.TracePacketResponse {
resp := &proto.TracePacketResponse{} resp := &proto.TracePacketResponse{}
for _, result := range trace.Results { for _, result := range trace.Results {
@ -119,10 +150,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
Message: result.Message, Message: result.Message,
Allowed: result.Allowed, Allowed: result.Allowed,
} }
if result.ForwarderAction != nil { if result.ForwarderAction != nil {
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr) details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
stage.ForwardingDetails = &details stage.ForwardingDetails = &details
} }
resp.Stages = append(resp.Stages, stage) resp.Stages = append(resp.Stages, stage)
} }
@ -130,5 +163,5 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
} }
return resp, nil return resp
} }

View File

@ -1,8 +1,10 @@
package net package net
import ( import (
"fmt"
"math/big" "math/big"
"net" "net"
"net/netip"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -54,11 +56,13 @@ func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString()) return ConnectionID(uuid.NewString())
} }
func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP { func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
// Calculate the last IP in the CIDR range
var endIP net.IP var endIP net.IP
for i := 0; i < len(network.IP); i++ { addr := network.Addr().AsSlice()
endIP = append(endIP, network.IP[i]|^network.Mask[i]) mask := net.CIDRMask(network.Bits(), len(addr)*8)
for i := 0; i < len(addr); i++ {
endIP = append(endIP, addr[i]|^mask[i])
} }
// convert to big.Int // convert to big.Int
@ -70,5 +74,10 @@ func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
resultInt := big.NewInt(0) resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig) resultInt.Sub(endInt, fromEndBig)
return resultInt.Bytes() ip, ok := netip.AddrFromSlice(resultInt.Bytes())
if !ok {
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
}
return ip.Unmap(), nil
} }

94
util/net/net_test.go Normal file
View File

@ -0,0 +1,94 @@
package net
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
name string
network string
fromEnd int
expected string
expectErr bool
}{
{
name: "IPv4 /24 network - last IP (fromEnd=0)",
network: "192.168.1.0/24",
fromEnd: 0,
expected: "192.168.1.255",
},
{
name: "IPv4 /24 network - fromEnd=1",
network: "192.168.1.0/24",
fromEnd: 1,
expected: "192.168.1.254",
},
{
name: "IPv4 /24 network - fromEnd=5",
network: "192.168.1.0/24",
fromEnd: 5,
expected: "192.168.1.250",
},
{
name: "IPv4 /16 network - last IP",
network: "10.0.0.0/16",
fromEnd: 0,
expected: "10.0.255.255",
},
{
name: "IPv4 /16 network - fromEnd=256",
network: "10.0.0.0/16",
fromEnd: 256,
expected: "10.0.254.255",
},
{
name: "IPv4 /32 network - single host",
network: "192.168.1.100/32",
fromEnd: 0,
expected: "192.168.1.100",
},
{
name: "IPv6 /64 network - last IP",
network: "2001:db8::/64",
fromEnd: 0,
expected: "2001:db8::ffff:ffff:ffff:ffff",
},
{
name: "IPv6 /64 network - fromEnd=1",
network: "2001:db8::/64",
fromEnd: 1,
expected: "2001:db8::ffff:ffff:ffff:fffe",
},
{
name: "IPv6 /128 network - single host",
network: "2001:db8::1/128",
fromEnd: 0,
expected: "2001:db8::1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
network, err := netip.ParsePrefix(tt.network)
require.NoError(t, err, "Failed to parse network prefix")
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
if tt.expectErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
expectedIP, err := netip.ParseAddr(tt.expected)
require.NoError(t, err, "Failed to parse expected IP")
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
})
}
}