mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-19 17:31:39 +02:00
[client] Use platform-native routing APIs for freeBSD, macOS and Windows
This commit is contained in:
parent
87148c503f
commit
ea4d13e96d
@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
|
@ -2,7 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Data: ip.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
|
@ -41,7 +41,7 @@ type Forwarder struct {
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip net.IP
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
}
|
||||
|
||||
@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
},
|
||||
}
|
||||
|
||||
@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: iface.Address().IP,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
|
||||
}
|
||||
|
||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
|
||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return value.([]byte), true
|
||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||
|
@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
}
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
|
@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
expected: true,
|
||||
@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
expected: true,
|
||||
@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
expected: true,
|
||||
@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
expected: true,
|
||||
@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
expected: false,
|
||||
|
@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -71,7 +71,6 @@ type Manager struct {
|
||||
// incomingRules is used for filtering and hooks
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
@ -1091,11 +1090,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
|
@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Apply scenario-specific setup
|
||||
sc.setupFunc(manager)
|
||||
|
||||
@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Pre-populate connection table
|
||||
srcIPs := generateRandomIPs(count)
|
||||
dstIPs := generateRandomIPs(count)
|
||||
@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||
@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "post_handshake",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
require.NoError(b, err)
|
||||
@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -19,12 +19,8 @@ import (
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
localIP := net.ParseIP("100.10.0.100")
|
||||
wgNet := &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
localIP, wgNet, err := net.ParseCIDR(network)
|
||||
require.NoError(tb, err)
|
||||
wgNet := netip.MustParsePrefix(network)
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: localIP,
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer func() {
|
||||
@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.decoders = sync.Pool{
|
||||
|
@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a.AsSlice()) {
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@ -24,9 +23,6 @@ type PacketFilter interface {
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
SetNetwork(*net.IPNet)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
||||
log.Info("create nbnetstack tun interface")
|
||||
|
||||
// TODO: get from service listener runtime IP
|
||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("last ip: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("netstack using address: %s", t.address.IP)
|
||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||
|
@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||
}
|
||||
|
||||
ip := address.IP.String()
|
||||
mask := "0x" + address.Network.Mask.String()
|
||||
|
||||
// Convert prefix length to hex netmask
|
||||
prefixLen := address.Network.Bits()
|
||||
if !address.IP.Is4() {
|
||||
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||
}
|
||||
|
||||
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||
|
||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||
|
||||
|
@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
||||
}
|
||||
|
||||
w.filter = filter
|
||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
||||
|
||||
w.tun.FilteredDevice().SetFilter(filter)
|
||||
return nil
|
||||
|
@ -5,7 +5,6 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
@ -15,8 +13,8 @@ import (
|
||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||
|
||||
type NetStackTun struct { //nolint:revive
|
||||
address net.IP
|
||||
dnsAddress net.IP
|
||||
address netip.Addr
|
||||
dnsAddress netip.Addr
|
||||
mtu int
|
||||
listenAddress string
|
||||
|
||||
@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
||||
tundev tun.Device
|
||||
}
|
||||
|
||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
||||
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||
return &NetStackTun{
|
||||
address: address,
|
||||
dnsAddress: dnsAddress,
|
||||
@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
||||
}
|
||||
|
||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||
addr, ok := netip.AddrFromSlice(t.address)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
||||
}
|
||||
|
||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
||||
}
|
||||
|
||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{addr.Unmap()},
|
||||
[]netip.Addr{dnsAddr.Unmap()},
|
||||
[]netip.Addr{t.address},
|
||||
[]netip.Addr{t.dnsAddress},
|
||||
t.mtu)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -2,28 +2,27 @@ package wgaddr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// Address WireGuard parsed address
|
||||
type Address struct {
|
||||
IP net.IP
|
||||
Network *net.IPNet
|
||||
IP netip.Addr
|
||||
Network netip.Prefix
|
||||
}
|
||||
|
||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||
func ParseWGAddress(address string) (Address, error) {
|
||||
ip, network, err := net.ParseCIDR(address)
|
||||
prefix, err := netip.ParsePrefix(address)
|
||||
if err != nil {
|
||||
return Address{}, err
|
||||
}
|
||||
return Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
IP: prefix.Addr().Unmap(),
|
||||
Network: prefix.Masked(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (addr Address) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
@ -43,12 +43,11 @@ func TestDefaultManager(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
require.NoError(t, err)
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
@ -162,12 +161,11 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
require.NoError(t, err)
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
@ -372,12 +370,11 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
require.NoError(t, err)
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: ip,
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
@ -2,7 +2,7 @@ package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@ -12,13 +12,14 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||
ip, err := netip.ParseAddr(aRecord.RData)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
if !prefix.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||
networkIP := network.Masked().Addr()
|
||||
|
||||
if !networkIP.Is4() {
|
||||
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||
}
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
octetsToUse := (network.Bits() + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||
zoneName, err := generateReverseZoneName(network)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
records := collectPTRRecords(config, network)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
|
@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return wgaddr.Address{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
IP: netip.MustParseAddr("100.66.100.1"),
|
||||
Network: netip.MustParsePrefix("100.66.100.0/24"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
||||
if err != nil {
|
||||
t.Errorf("parse CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
|
@ -24,11 +24,15 @@ type ServiceViaMemory struct {
|
||||
}
|
||||
|
||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
|
||||
runtimeIP: lastIP.String(),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
if s.wgInterface.Address().IP.Is6() {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
|
@ -1,33 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@ -23,8 +24,8 @@ type upstreamResolver struct {
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
|
@ -4,7 +4,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@ -19,8 +19,8 @@ type upstreamResolver struct {
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
_ net.IP,
|
||||
_ *net.IPNet,
|
||||
_ netip.Addr,
|
||||
_ netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@ -18,16 +19,16 @@ import (
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
lIP netip.Addr
|
||||
lNet netip.Prefix
|
||||
interfaceName string
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
ctx context.Context,
|
||||
interfaceName string,
|
||||
ip net.IP,
|
||||
net *net.IPNet,
|
||||
ip netip.Addr,
|
||||
net netip.Prefix,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
client.DialTimeout = timeout
|
||||
|
||||
upstreamIP := net.ParseIP(upstreamHost)
|
||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||
}
|
||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: ip,
|
||||
IP: ip.AsSlice(),
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
|
@ -2,7 +2,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
||||
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||
resolver.upstreamServers = testCase.InputServers
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
|
@ -1008,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
log.Errorf("failed to update routes: %v", err)
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
@ -1104,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
Network: prefix.Masked(),
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
@ -1138,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
return entries
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
@ -1790,9 +1790,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
}
|
||||
|
||||
// GetWgAddr returns the wireguard address
|
||||
func (e *Engine) GetWgAddr() net.IP {
|
||||
func (e *Engine) GetWgAddr() netip.Addr {
|
||||
if e.wgInterface == nil {
|
||||
return nil
|
||||
return netip.Addr{}
|
||||
}
|
||||
return e.wgInterface.Address().IP
|
||||
}
|
||||
@ -1861,12 +1861,7 @@ func (e *Engine) Address() (netip.Addr, error) {
|
||||
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
addr := e.wgInterface.Address()
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
|
||||
}
|
||||
return ip.Unmap(), nil
|
||||
return e.wgInterface.Address().IP, nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||
|
@ -371,11 +371,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
}
|
||||
},
|
||||
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
|
@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
||||
|
||||
// fallback if mark rules are not in place
|
||||
wgnet := c.iface.Address().Network
|
||||
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
|
||||
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
|
||||
// fallback if marks are not set
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||
|
||||
switch {
|
||||
case wgaddr.Equal(src):
|
||||
case wgaddr == srcIP:
|
||||
return nftypes.Egress
|
||||
case wgaddr.Equal(dst):
|
||||
case wgaddr == dstIP:
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(src):
|
||||
case wgnetwork.Contains(srcIP):
|
||||
// netbird network -> resource network
|
||||
return nftypes.Ingress
|
||||
case wgnetwork.Contains(dst):
|
||||
case wgnetwork.Contains(dstIP):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -23,17 +23,16 @@ type Logger struct {
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
wgIfaceNet netip.Prefix
|
||||
dnsCollection atomic.Bool
|
||||
exitNodeCollection atomic.Bool
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
|
||||
return &Logger{
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
wgIfaceNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
|
||||
var isSrcExitNode bool
|
||||
var isDestExitNode bool
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
if !l.wgIfaceNet.Contains(event.SourceIP) {
|
||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
if !l.wgIfaceNet.Contains(event.DestIP) {
|
||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger := logger.New(nil, netip.Prefix{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@ -34,11 +34,11 @@ type Manager struct {
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
var prefix netip.Prefix
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
prefix = iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, prefix)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
|
@ -1,7 +1,7 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
@ -264,7 +264,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
||||
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||
newPrefixes = append(newPrefixes, prefix)
|
||||
}
|
||||
|
||||
|
@ -333,11 +333,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
var merr *multierror.Error
|
||||
if !m.disableClientRoutes {
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
|
||||
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
|
||||
log.Errorf("Failed to update system routes: %v", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
|
||||
}
|
||||
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
@ -346,14 +347,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
m.clientRoutes = newClientRoutesIDMap
|
||||
|
||||
if m.serverRouter == nil {
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
||||
return fmt.Errorf("update routes: %w", err)
|
||||
merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||
|
@ -44,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -71,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.252.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.252.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -99,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.30.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -127,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.30.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -211,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -233,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -250,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -272,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -282,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "b",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey2,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -299,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -327,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -356,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "l1",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -376,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
ID: "r1",
|
||||
NetID: "routeA",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
||||
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
@ -440,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
|
@ -13,7 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -22,8 +22,13 @@ const (
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
type iface interface {
|
||||
Address() wgaddr.Address
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface iface.WGIface) (map[string]int, error) {
|
||||
func Setup(wgIface iface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
|
@ -6,9 +6,10 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
@ -30,11 +31,16 @@ func (n Nexthop) String() string {
|
||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||
}
|
||||
|
||||
type wgIface interface {
|
||||
Address() wgaddr.Address
|
||||
Name() string
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface iface.WGIface
|
||||
wgInterface wgIface
|
||||
// prefixes is tracking all the current added prefixes im memory
|
||||
// (this is used in iOS as all route updates require a full table update)
|
||||
//nolint
|
||||
@ -45,9 +51,27 @@ type SysOps struct {
|
||||
notifier *notifier.Notifier
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
|
||||
switch {
|
||||
case
|
||||
!addr.IsValid(),
|
||||
addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsMulticast(),
|
||||
addr.IsUnspecified() && prefix.Bits() != 0,
|
||||
r.wgInterface.Address().Network.Contains(addr):
|
||||
return vars.ErrRouteNotAllowed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@ -33,7 +35,12 @@ func init() {
|
||||
|
||||
func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
var intf *net.Interface
|
||||
var nexthop Nexthop
|
||||
|
||||
_, intf = setupDummyInterface(t)
|
||||
nexthop = Nexthop{netip.Addr{}, intf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
|
||||
@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
if runtime.GOOS == "darwin" {
|
||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||
require.NoError(t, err, "Failed to create loopback alias")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
})
|
||||
|
||||
return intf
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||
require.NoError(t, err, "Failed to parse prefix")
|
||||
|
||||
netIntf, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||
|
||||
r := NewSysOps(nil, nil)
|
||||
err = r.addToRouteTable(prefix, nexthop)
|
||||
require.NoError(t, err, "Failed to add route to table")
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||
err := r.removeFromRouteTable(prefix, nexthop)
|
||||
assert.NoError(t, err, "Failed to remove route from table")
|
||||
})
|
||||
|
||||
return "lo0"
|
||||
return intf
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
||||
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||
t.Helper()
|
||||
|
||||
var originalNexthop net.IP
|
||||
@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
|
||||
return net.ParseIP(matches[1]), nil
|
||||
}
|
||||
|
||||
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "darwin" {
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||
}
|
||||
|
||||
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||
|
||||
tunName := strings.TrimSpace(string(output))
|
||||
|
||||
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||
|
||||
intf, err := net.InterfaceByName(tunName)
|
||||
require.NoError(t, err, "Failed to get interface by name")
|
||||
|
||||
t.Cleanup(func() {
|
||||
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||
}
|
||||
})
|
||||
|
||||
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
||||
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||
|
||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
||||
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||
}
|
||||
|
@ -17,7 +17,6 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(nexthop.IP) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
||||
if nexthop.IP.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
nexthop, err = GetNextHop(nexthop.IP)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
||||
return r.addToRouteTable(gatewayPrefix, nexthop)
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return Nexthop{}, err
|
||||
}
|
||||
|
||||
addr := prefix.Addr()
|
||||
if addr.IsUnspecified() {
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
vpnAddr := vpnIntf.Address().IP
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
||||
return r.addToRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
|
||||
if preferredSrc == nil {
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
localRoutes, err := hasSeparateRouting()
|
||||
|
@ -3,23 +3,25 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
@ -27,105 +29,370 @@ type dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func TestAddRemoveRoutes(t *testing.T) {
|
||||
func TestAddVPNRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
shouldRouteToWireguard bool
|
||||
shouldBeRemoved bool
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove Route 100.66.120.0/24",
|
||||
prefix: netip.MustParsePrefix("100.66.120.0/24"),
|
||||
shouldRouteToWireguard: true,
|
||||
shouldBeRemoved: true,
|
||||
name: "IPv4 - Private network route",
|
||||
prefix: netip.MustParsePrefix("10.10.100.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Or Remove Route 127.0.0.1/32",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
shouldRouteToWireguard: false,
|
||||
shouldBeRemoved: false,
|
||||
name: "IPv4 Single host",
|
||||
prefix: netip.MustParsePrefix("10.111.111.111/32"),
|
||||
},
|
||||
{
|
||||
name: "IPv4 RFC3927 test range",
|
||||
prefix: netip.MustParsePrefix("198.51.100.0/24"),
|
||||
},
|
||||
{
|
||||
name: "IPv4 Default route",
|
||||
prefix: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
|
||||
{
|
||||
name: "IPv6 Subnet",
|
||||
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Single host",
|
||||
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Default route",
|
||||
prefix: netip.MustParsePrefix("::/0"),
|
||||
},
|
||||
|
||||
// IPv4 addresses that should be rejected (matches validateRoute logic)
|
||||
{
|
||||
name: "IPv4 Loopback",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("169.254.1.1/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Link-local multicast",
|
||||
prefix: netip.MustParsePrefix("224.0.0.251/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Multicast",
|
||||
prefix: netip.MustParsePrefix("239.255.255.250/32"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Unspecified with prefix",
|
||||
prefix: netip.MustParsePrefix("0.0.0.0/32"),
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// IPv6 addresses that should be rejected (matches validateRoute logic)
|
||||
{
|
||||
name: "IPv6 Loopback",
|
||||
prefix: netip.MustParsePrefix("::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("fe80::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Link-local multicast",
|
||||
prefix: netip.MustParsePrefix("ff02::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Interface-local multicast",
|
||||
prefix: netip.MustParsePrefix("ff01::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Multicast",
|
||||
prefix: netip.MustParsePrefix("ff00::1/128"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Unspecified with prefix",
|
||||
prefix: netip.MustParsePrefix("::/128"),
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "IPv4 WireGuard interface network overlap",
|
||||
prefix: netip.MustParsePrefix("100.65.75.0/24"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv4 WireGuard interface network subnet",
|
||||
prefix: netip.MustParsePrefix("100.65.75.0/32"),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping ", testCase.name, " on freebsd")
|
||||
}
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
||||
Address: "100.65.75.2/24",
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
_, _, err = r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
// add the route
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
if testCase.expectError {
|
||||
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
||||
// validate it's pointing to the WireGuard interface
|
||||
require.NoError(t, err)
|
||||
|
||||
nextHop := getNextHop(t, testCase.prefix.Addr())
|
||||
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
|
||||
|
||||
// remove route again
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// validate it's gone
|
||||
nextHop, err = GetNextHop(testCase.prefix.Addr())
|
||||
require.True(t,
|
||||
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
|
||||
"err: %v, next hop: %v", err, nextHop)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
|
||||
nextHop, err := GetNextHop(addr)
|
||||
|
||||
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
|
||||
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
|
||||
// present in the route table.
|
||||
t.Skip("Skipping windows test")
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
|
||||
|
||||
return nextHop
|
||||
}
|
||||
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
|
||||
|
||||
if addr.IsUnspecified() {
|
||||
// On macOS, querying 0.0.0.0 returns the wrong interface
|
||||
if addr.Is4() {
|
||||
addr = netip.MustParseAddr("1.2.3.4")
|
||||
} else {
|
||||
addr = netip.MustParseAddr("2001:db8::1")
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command("route", "-n", "get", addr.String())
|
||||
if addr.Is6() {
|
||||
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
|
||||
}
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
t.Logf("route output: %s", output)
|
||||
require.NoError(t, err, "%s failed")
|
||||
|
||||
lines := strings.Split(string(output), "\n")
|
||||
var intf string
|
||||
var gateway string
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "interface:") {
|
||||
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
|
||||
} else if strings.HasPrefix(line, "gateway:") {
|
||||
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
|
||||
}
|
||||
}
|
||||
|
||||
require.NotEmpty(t, intf, "interface should be found in route output")
|
||||
|
||||
iface, err := net.InterfaceByName(intf)
|
||||
require.NoError(t, err, "interface %s should exist", intf)
|
||||
|
||||
nexthop := Nexthop{Intf: iface}
|
||||
|
||||
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
|
||||
addr, err := netip.ParseAddr(gateway)
|
||||
if err == nil {
|
||||
nexthop.IP = addr
|
||||
}
|
||||
}
|
||||
|
||||
return nexthop
|
||||
}
|
||||
|
||||
func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
expectError bool
|
||||
errorType error
|
||||
}{
|
||||
{
|
||||
name: "IPv4 RFC3927 test range",
|
||||
prefix: netip.MustParsePrefix("198.51.100.0/24"),
|
||||
},
|
||||
{
|
||||
name: "IPv4 Single host",
|
||||
prefix: netip.MustParsePrefix("8.8.8.8/32"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 External network route",
|
||||
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Single host",
|
||||
prefix: netip.MustParsePrefix("2001:db8::1/128"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Subnet",
|
||||
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
|
||||
},
|
||||
{
|
||||
name: "IPv6 Single host",
|
||||
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
|
||||
},
|
||||
|
||||
// Addresses that should be rejected
|
||||
{
|
||||
name: "IPv4 Loopback",
|
||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("169.254.1.1/32"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Multicast",
|
||||
prefix: netip.MustParsePrefix("239.255.255.250/32"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 Unspecified",
|
||||
prefix: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Loopback",
|
||||
prefix: netip.MustParsePrefix("::1/128"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Link-local unicast",
|
||||
prefix: netip.MustParsePrefix("fe80::1/128"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Multicast",
|
||||
prefix: netip.MustParsePrefix("ff00::1/128"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv6 Unspecified",
|
||||
prefix: netip.MustParsePrefix("::/0"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "IPv4 WireGuard interface network overlap",
|
||||
prefix: netip.MustParsePrefix("100.65.75.0/24"),
|
||||
expectError: true,
|
||||
errorType: vars.ErrRouteNotAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
require.NoError(t, err, "Should be able to get IPv4 default route")
|
||||
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
|
||||
|
||||
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if testCase.prefix.Addr().Is6() &&
|
||||
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
|
||||
t.Skip("Skipping test as no ipv6 default route is available")
|
||||
}
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
t.Fatalf("Failed to get IPv6 default route: %v", err)
|
||||
}
|
||||
|
||||
var initialNextHop Nexthop
|
||||
if testCase.prefix.Addr().Is6() {
|
||||
initialNextHop = initialNextHopV6
|
||||
} else {
|
||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
||||
initialNextHop = initialNextHopV4
|
||||
}
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
|
||||
|
||||
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
if testCase.expectError {
|
||||
require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
|
||||
|
||||
// Verify the route was added and points to non-VPN interface
|
||||
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
|
||||
|
||||
err = r.removeFromRouteTable(testCase.prefix, nexthop)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNextHop(t *testing.T) {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping on freebsd")
|
||||
}
|
||||
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if !nexthop.IP.IsValid() {
|
||||
if !defaultNh.IP.IsValid() {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var testingIP string
|
||||
var testingPrefix netip.Prefix
|
||||
for _, address := range addresses {
|
||||
if address.Network() != "ip+net" {
|
||||
@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
|
||||
}
|
||||
prefix := netip.MustParsePrefix(address.String())
|
||||
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
|
||||
testingIP = prefix.Addr().String()
|
||||
testingPrefix = prefix.Masked()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
localIP, err := GetNextHop(testingPrefix.Addr())
|
||||
nh, err := GetNextHop(testingPrefix.Addr())
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if !localIP.IP.IsValid() {
|
||||
if nh.Intf == nil {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.IP.String() == nexthop.IP.String() {
|
||||
t.Fatal("local IP should not match with gateway IP")
|
||||
if nh.IP.String() == defaultNh.IP.String() {
|
||||
t.Fatal("next hop IP should not match with default gateway IP")
|
||||
}
|
||||
if localIP.IP.String() != testingIP {
|
||||
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultNexthop: ", defaultNexthop)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix netip.Prefix
|
||||
preExistingPrefix netip.Prefix
|
||||
shouldAddRoute bool
|
||||
}{
|
||||
{
|
||||
name: "Should Add And Remove random Route",
|
||||
prefix: netip.MustParsePrefix("99.99.99.99/32"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if bigger network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Add Route if smaller network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
shouldAddRoute: true,
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if same network exists",
|
||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
|
||||
var buf bytes.Buffer
|
||||
log.SetOutput(&buf)
|
||||
defer func() {
|
||||
log.SetOutput(os.Stderr)
|
||||
}()
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
||||
Address: "100.65.75.2/24",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
// test if route exists after adding
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "should not return err")
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
// route should either not have been added or should have been removed
|
||||
// In case of already existing route, it should not have been added (but still exist)
|
||||
ok, err := existsInRouteTable(testCase.prefix)
|
||||
t.Log("Buffer string: ", buf.String())
|
||||
require.NoError(t, err, "should not return err")
|
||||
|
||||
if !strings.Contains(buf.String(), "because it already exists") {
|
||||
require.False(t, ok, "route should not exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubRange(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var subRangeAddressPrefixes []netip.Prefix
|
||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range subRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if !isSubRangePrefix {
|
||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
||||
isSubRangePrefix, err := isSubRange(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
||||
}
|
||||
if isSubRangePrefix {
|
||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistsInRouteTable(t *testing.T) {
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||
}
|
||||
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
|
||||
switch {
|
||||
case p.Addr().Is6():
|
||||
continue
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
|
||||
continue
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
|
||||
continue
|
||||
// FreeBSD loopback 127/8 is not added to the routing table
|
||||
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
|
||||
continue
|
||||
default:
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
exists, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("address %s should exist in route table", prefix)
|
||||
}
|
||||
if nh.Intf.Name != defaultNh.Intf.Name {
|
||||
t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
err := r.AddVPNRoute(prefix, intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
if err := r.AddVPNRoute(prefix, intf); err != nil {
|
||||
if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
|
||||
t.Fatalf("addVPNRoute should not return err: %v", err)
|
||||
}
|
||||
t.Logf("addVPNRoute %v returned: %v", prefix, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
err = r.RemoveVPNRoute(prefix, intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
|
||||
t.Fatalf("removeVPNRoute should not return err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
|
||||
// unique route in vpn table
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
||||
return
|
||||
}
|
||||
|
||||
prefixNexthop, err := GetNextHop(prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
|
||||
ones, _ := route.Dst.Mask.Size()
|
||||
|
||||
prefix := netip.PrefixFrom(addr, ones)
|
||||
prefix := netip.PrefixFrom(addr.Unmap(), ones)
|
||||
if prefix.IsValid() {
|
||||
prefixList = append(prefixList, prefix)
|
||||
}
|
||||
@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
|
||||
return fmt.Errorf("netlink add route: %w", err)
|
||||
}
|
||||
|
||||
@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
||||
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
|
||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,6 @@ import (
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
var expectedLoopbackInt = "lo"
|
||||
var expectedExternalInt = "dummyext0"
|
||||
var expectedInternalInt = "dummyint0"
|
||||
|
||||
@ -31,12 +30,6 @@ func init() {
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
||||
},
|
||||
{
|
||||
name: "To more specific route (local) without custom dialer via physical interface",
|
||||
expectedInterface: expectedLoopbackInt,
|
||||
dialer: &net.Dialer{},
|
||||
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
|
@ -11,10 +11,16 @@ import (
|
||||
)
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if err := r.validateRoute(prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
|
268
client/internal/routemanager/systemops/systemops_test.go
Normal file
268
client/internal/routemanager/systemops/systemops_test.go
Normal file
@ -0,0 +1,268 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
type mockWGIface struct {
|
||||
address wgaddr.Address
|
||||
name string
|
||||
}
|
||||
|
||||
func (m *mockWGIface) Address() wgaddr.Address {
|
||||
return m.address
|
||||
}
|
||||
|
||||
func (m *mockWGIface) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func TestSysOps_validateRoute(t *testing.T) {
|
||||
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
|
||||
mockWG := &mockWGIface{
|
||||
address: wgaddr.Address{
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "wg0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
wgInterface: mockWG,
|
||||
notifier: ¬ifier.Notifier{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expectError bool
|
||||
}{
|
||||
// Valid routes
|
||||
{
|
||||
name: "valid IPv4 route",
|
||||
prefix: "192.168.1.0/24",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid IPv6 route",
|
||||
prefix: "2001:db8::/32",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid single IPv4 host",
|
||||
prefix: "8.8.8.8/32",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid single IPv6 host",
|
||||
prefix: "2001:4860:4860::8888/128",
|
||||
expectError: false,
|
||||
},
|
||||
|
||||
// Invalid routes - loopback
|
||||
{
|
||||
name: "IPv4 loopback",
|
||||
prefix: "127.0.0.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
prefix: "::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - link-local unicast
|
||||
{
|
||||
name: "IPv4 link-local unicast",
|
||||
prefix: "169.254.1.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 link-local unicast",
|
||||
prefix: "fe80::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - multicast
|
||||
{
|
||||
name: "IPv4 multicast",
|
||||
prefix: "224.0.0.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 multicast",
|
||||
prefix: "ff02::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - link-local multicast
|
||||
{
|
||||
name: "IPv4 link-local multicast",
|
||||
prefix: "224.0.0.0/24",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 link-local multicast",
|
||||
prefix: "ff02::/16",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - interface-local multicast (IPv6 only)
|
||||
{
|
||||
name: "IPv6 interface-local multicast",
|
||||
prefix: "ff01::1/128",
|
||||
expectError: true,
|
||||
},
|
||||
|
||||
// Invalid routes - overlaps with WG interface network
|
||||
{
|
||||
name: "overlaps with WG network - exact match",
|
||||
prefix: "10.0.0.0/24",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "overlaps with WG network - subset",
|
||||
prefix: "10.0.0.1/32",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "overlaps with WG network - host in range",
|
||||
prefix: "10.0.0.100/32",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefix, err := netip.ParsePrefix(tt.prefix)
|
||||
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
|
||||
|
||||
err = sysOps.validateRoute(prefix)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
|
||||
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
|
||||
} else {
|
||||
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
|
||||
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
|
||||
mockWG := &mockWGIface{
|
||||
address: wgaddr.Address{
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "wg0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
wgInterface: mockWG,
|
||||
notifier: ¬ifier.Notifier{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "identical subnet",
|
||||
prefix: "192.168.100.0/24",
|
||||
expectError: true,
|
||||
description: "exact same network as WG interface",
|
||||
},
|
||||
{
|
||||
name: "broader subnet containing WG network",
|
||||
prefix: "192.168.0.0/16",
|
||||
expectError: false,
|
||||
description: "broader network that contains WG network should be allowed",
|
||||
},
|
||||
{
|
||||
name: "host within WG network",
|
||||
prefix: "192.168.100.50/32",
|
||||
expectError: true,
|
||||
description: "specific host within WG network",
|
||||
},
|
||||
{
|
||||
name: "subnet within WG network",
|
||||
prefix: "192.168.100.128/25",
|
||||
expectError: true,
|
||||
description: "smaller subnet within WG network",
|
||||
},
|
||||
{
|
||||
name: "adjacent subnet - same /23",
|
||||
prefix: "192.168.101.0/24",
|
||||
expectError: false,
|
||||
description: "adjacent subnet, no overlap",
|
||||
},
|
||||
{
|
||||
name: "adjacent subnet - different /16",
|
||||
prefix: "192.167.100.0/24",
|
||||
expectError: false,
|
||||
description: "different network, no overlap",
|
||||
},
|
||||
{
|
||||
name: "WG network broadcast address",
|
||||
prefix: "192.168.100.255/32",
|
||||
expectError: true,
|
||||
description: "broadcast address of WG network",
|
||||
},
|
||||
{
|
||||
name: "WG network first usable",
|
||||
prefix: "192.168.100.1/32",
|
||||
expectError: true,
|
||||
description: "first usable address in WG network",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefix, err := netip.ParsePrefix(tt.prefix)
|
||||
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
|
||||
|
||||
err = sysOps.validateRoute(prefix)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
|
||||
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
|
||||
} else {
|
||||
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
||||
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
|
||||
mockWG := &mockWGIface{
|
||||
address: wgaddr.Address{
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "wt0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
wgInterface: mockWG,
|
||||
notifier: ¬ifier.Notifier{},
|
||||
}
|
||||
|
||||
var invalidPrefix netip.Prefix
|
||||
err := sysOps.validateRoute(invalidPrefix)
|
||||
|
||||
require.Error(t, err, "validateRoute() expected error for invalid prefix")
|
||||
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
|
||||
}
|
@ -3,15 +3,19 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("add", prefix, nexthop)
|
||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("delete", prefix, nexthop)
|
||||
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
inet := "-inet"
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
}
|
||||
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IP.IsValid() {
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
} else if nexthop.Intf != nil {
|
||||
args = append(args, "-interface", nexthop.Intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
||||
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func retryRouteCmd(args []string) error {
|
||||
operation := func() error {
|
||||
out, err := exec.Command("route", args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
// https://github.com/golang/go/issues/45736
|
||||
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
|
||||
return err
|
||||
} else if err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return nil
|
||||
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
if !prefix.IsValid() {
|
||||
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||
}
|
||||
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||
|
||||
err := backoff.Retry(operation, expBackOff)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route cmd retry failed: %w", err)
|
||||
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
||||
a := "add"
|
||||
if action == unix.RTM_DELETE {
|
||||
a = "remove"
|
||||
}
|
||||
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
||||
operation := func() error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
||||
}
|
||||
|
||||
msgBytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
||||
}
|
||||
|
||||
if _, err = unix.Write(fd, msgBytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
}
|
||||
|
||||
respBuf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, respBuf)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return operation
|
||||
}
|
||||
|
||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||
msg = &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: 1,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
addrs := make([]route.Addr, numAddrs)
|
||||
|
||||
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
|
||||
}
|
||||
|
||||
if prefix.IsSingleIP() {
|
||||
msg.Flags |= unix.RTF_HOST
|
||||
} else {
|
||||
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
msg.Flags |= unix.RTF_GATEWAY
|
||||
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
|
||||
}
|
||||
} else if nexthop.Intf != nil {
|
||||
msg.Index = nexthop.Intf.Index
|
||||
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
||||
Index: nexthop.Intf.Index,
|
||||
Name: nexthop.Intf.Name,
|
||||
}
|
||||
}
|
||||
|
||||
msg.Addrs = addrs
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
return nil
|
||||
}
|
||||
|
||||
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
if rtMsg.Errno != 0 {
|
||||
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||
if addr.Is4() {
|
||||
return &route.Inet4Addr{IP: addr.As4()}, nil
|
||||
}
|
||||
|
||||
if addr.Zone() == "" {
|
||||
return &route.Inet6Addr{IP: addr.As16()}, nil
|
||||
}
|
||||
|
||||
var zone int
|
||||
// zone can be either a numeric zone ID or an interface name.
|
||||
if z, err := strconv.Atoi(addr.Zone()); err == nil {
|
||||
zone = z
|
||||
} else {
|
||||
iface, err := net.InterfaceByName(addr.Zone())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
|
||||
}
|
||||
zone = iface.Index
|
||||
}
|
||||
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
|
||||
}
|
||||
|
||||
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
m := net.CIDRMask(bits, 32)
|
||||
var maskBytes [4]byte
|
||||
copy(maskBytes[:], m)
|
||||
return &route.Inet4Addr{IP: maskBytes}, nil
|
||||
}
|
||||
|
||||
if prefix.Addr().Is6() {
|
||||
m := net.CIDRMask(bits, 128)
|
||||
var maskBytes [16]byte
|
||||
copy(maskBytes[:], m)
|
||||
return &route.Inet6Addr{IP: maskBytes}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
||||
}
|
||||
|
@ -1,5 +1,3 @@
|
||||
//go:build windows
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
@ -9,9 +7,8 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
@ -21,11 +18,12 @@ import (
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const InfiniteLifetime = 0xffffffff
|
||||
|
||||
type RouteUpdateType int
|
||||
|
||||
// RouteUpdate represents a change in the routing table.
|
||||
@ -58,9 +56,13 @@ type MSFT_NetRoute struct {
|
||||
AddressFamily uint16
|
||||
}
|
||||
|
||||
// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
|
||||
// luid represents a locally unique identifier for network interfaces
|
||||
type luid uint64
|
||||
|
||||
// MIB_IPFORWARD_ROW2 represents a route entry in the routing table.
|
||||
// It is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
|
||||
type MIB_IPFORWARD_ROW2 struct {
|
||||
InterfaceLuid uint64
|
||||
InterfaceLuid luid
|
||||
InterfaceIndex uint32
|
||||
DestinationPrefix IP_ADDRESS_PREFIX
|
||||
NextHop SOCKADDR_INET_NEXTHOP
|
||||
@ -108,9 +110,14 @@ type SOCKADDR_INET_NEXTHOP struct {
|
||||
type MIB_NOTIFICATION_TYPE int32
|
||||
|
||||
var (
|
||||
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
|
||||
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
|
||||
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
|
||||
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
|
||||
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
|
||||
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
|
||||
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
|
||||
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
|
||||
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||
|
||||
prefixList []netip.Prefix
|
||||
lastUpdate time.Time
|
||||
@ -139,6 +146,8 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
log.Debugf("Adding route to %s via %s", prefix, nexthop)
|
||||
// if we don't have an interface but a zone, extract the interface index from the zone
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.IP.Zone())
|
||||
if err != nil {
|
||||
@ -147,23 +156,187 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
nexthop.Intf = &net.Interface{Index: zone}
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop)
|
||||
return addRoute(prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IP.IsValid() {
|
||||
ip := nexthop.IP.WithZone("")
|
||||
args = append(args, ip.Unmap().String())
|
||||
log.Debugf("Removing route to %s via %s", prefix, nexthop)
|
||||
return deleteRoute(prefix, nexthop)
|
||||
}
|
||||
|
||||
// setupRouteEntry prepares a route entry with common configuration
|
||||
func setupRouteEntry(prefix netip.Prefix, nexthop Nexthop) (*MIB_IPFORWARD_ROW2, error) {
|
||||
route := &MIB_IPFORWARD_ROW2{}
|
||||
|
||||
initializeIPForwardEntry(route)
|
||||
|
||||
// Convert interface index to luid if interface is specified
|
||||
if nexthop.Intf != nil {
|
||||
var luid luid
|
||||
if err := convertInterfaceIndexToLUID(uint32(nexthop.Intf.Index), &luid); err != nil {
|
||||
return nil, fmt.Errorf("convert interface index to luid: %w", err)
|
||||
}
|
||||
route.InterfaceLuid = luid
|
||||
route.InterfaceIndex = uint32(nexthop.Intf.Index)
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
if err := setDestinationPrefix(&route.DestinationPrefix, prefix); err != nil {
|
||||
return nil, fmt.Errorf("set destination prefix: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
if nexthop.IP.IsValid() {
|
||||
if err := setNextHop(&route.NextHop, nexthop.IP); err != nil {
|
||||
return nil, fmt.Errorf("set next hop: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// addRoute adds a route using Windows iphelper APIs
|
||||
func addRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in addRoute: %v, stack trace: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
route, setupErr := setupRouteEntry(prefix, nexthop)
|
||||
if setupErr != nil {
|
||||
return fmt.Errorf("setup route entry: %w", setupErr)
|
||||
}
|
||||
|
||||
route.Metric = 1
|
||||
route.ValidLifetime = InfiniteLifetime
|
||||
route.PreferredLifetime = InfiniteLifetime
|
||||
|
||||
return createIPForwardEntry2(route)
|
||||
}
|
||||
|
||||
// deleteRoute deletes a route using Windows iphelper APIs
|
||||
func deleteRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in deleteRoute: %v, stack trace: %s", r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
route, setupErr := setupRouteEntry(prefix, nexthop)
|
||||
if setupErr != nil {
|
||||
return fmt.Errorf("setup route entry: %w", setupErr)
|
||||
}
|
||||
|
||||
if err := getIPForwardEntry2(route); err != nil {
|
||||
return fmt.Errorf("get route entry: %w", err)
|
||||
}
|
||||
|
||||
return deleteIPForwardEntry2(route)
|
||||
}
|
||||
|
||||
// setDestinationPrefix sets the destination prefix in the route structure
|
||||
func setDestinationPrefix(prefix *IP_ADDRESS_PREFIX, dest netip.Prefix) error {
|
||||
addr := dest.Addr()
|
||||
prefix.PrefixLength = uint8(dest.Bits())
|
||||
|
||||
if addr.Is4() {
|
||||
prefix.Prefix.sin6_family = windows.AF_INET
|
||||
ip4 := addr.As4()
|
||||
binary.BigEndian.PutUint32(prefix.Prefix.data[:4],
|
||||
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
|
||||
return nil
|
||||
}
|
||||
|
||||
if addr.Is6() {
|
||||
prefix.Prefix.sin6_family = windows.AF_INET6
|
||||
ip6 := addr.As16()
|
||||
copy(prefix.Prefix.data[4:20], ip6[:])
|
||||
|
||||
if zone := addr.Zone(); zone != "" {
|
||||
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
|
||||
binary.BigEndian.PutUint32(prefix.Prefix.data[20:24], uint32(scopeID))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid address family")
|
||||
}
|
||||
|
||||
// setNextHop sets the next hop address in the route structure
|
||||
func setNextHop(nextHop *SOCKADDR_INET_NEXTHOP, addr netip.Addr) error {
|
||||
if addr.Is4() {
|
||||
nextHop.sin6_family = windows.AF_INET
|
||||
ip4 := addr.As4()
|
||||
binary.BigEndian.PutUint32(nextHop.data[:4],
|
||||
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
|
||||
return nil
|
||||
}
|
||||
|
||||
if addr.Is6() {
|
||||
nextHop.sin6_family = windows.AF_INET6
|
||||
ip6 := addr.As16()
|
||||
copy(nextHop.data[4:20], ip6[:])
|
||||
|
||||
// Handle zone if present
|
||||
if zone := addr.Zone(); zone != "" {
|
||||
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
|
||||
binary.BigEndian.PutUint32(nextHop.data[20:24], uint32(scopeID))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid address family")
|
||||
}
|
||||
|
||||
// Windows API wrappers
|
||||
func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
|
||||
r1, _, e1 := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
|
||||
if r1 != 0 {
|
||||
if e1 != 0 {
|
||||
return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
|
||||
}
|
||||
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
|
||||
r1, _, e1 := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
|
||||
if r1 != 0 {
|
||||
if e1 != 0 {
|
||||
return fmt.Errorf("DeleteIpForwardEntry2: %w", e1)
|
||||
}
|
||||
return fmt.Errorf("DeleteIpForwardEntry2: code %d", r1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
|
||||
r1, _, e1 := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
|
||||
if r1 != 0 {
|
||||
if e1 != 0 {
|
||||
return fmt.Errorf("GetIpForwardEntry2: %w", e1)
|
||||
}
|
||||
return fmt.Errorf("GetIpForwardEntry2: code %d", r1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-initializeipforwardentry
|
||||
func initializeIPForwardEntry(route *MIB_IPFORWARD_ROW2) {
|
||||
// Does not return anything. Trying to handle the error might return an uninitialized value.
|
||||
_, _, _ = syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route)))
|
||||
}
|
||||
|
||||
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *luid) error {
|
||||
r1, _, e1 := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(),
|
||||
uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)))
|
||||
if r1 != 0 {
|
||||
if e1 != 0 {
|
||||
return fmt.Errorf("ConvertInterfaceIndexToLuid: %w", e1)
|
||||
}
|
||||
return fmt.Errorf("ConvertInterfaceIndexToLuid: code %d", r1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -319,7 +492,7 @@ func cancelMibChangeNotify2(handle windows.Handle) error {
|
||||
}
|
||||
|
||||
// GetRoutesFromTable returns the current routing table from with prefixes only.
|
||||
// It ccaches the result for 2 seconds to avoid blocking the caller.
|
||||
// It caches the result for 2 seconds to avoid blocking the caller.
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
@ -388,35 +561,6 @@ func GetRoutes() ([]Route, error) {
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"add", prefix.String()}
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
ip := nexthop.IP.WithZone("")
|
||||
args = append(args, ip.Unmap().String())
|
||||
} else {
|
||||
addr := "0.0.0.0"
|
||||
if prefix.Addr().Is6() {
|
||||
addr = "::"
|
||||
}
|
||||
args = append(args, addr)
|
||||
}
|
||||
|
||||
if nexthop.Intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route add: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCacheDisabled() bool {
|
||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||
}
|
||||
|
@ -5,18 +5,23 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var expectedExtInt = "Ethernet1"
|
||||
var (
|
||||
expectedExternalInt = "Ethernet1"
|
||||
expectedVPNint = "wgtest0"
|
||||
)
|
||||
|
||||
type RouteInfo struct {
|
||||
NextHop string `json:"nexthop"`
|
||||
@ -43,8 +48,6 @@ type testCase struct {
|
||||
dialer dialer
|
||||
}
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
|
||||
var testCases = []testCase{
|
||||
{
|
||||
name: "To external host without custom dialer via vpn",
|
||||
@ -52,14 +55,14 @@ var testCases = []testCase{
|
||||
expectedSourceIP: "100.64.0.1",
|
||||
expectedDestPrefix: "128.0.0.0/1",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "wgtest0",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
{
|
||||
name: "To external host with custom dialer via physical interface",
|
||||
destination: "192.0.2.1:53",
|
||||
expectedDestPrefix: "192.0.2.1/32",
|
||||
expectedInterface: expectedExtInt,
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
},
|
||||
|
||||
@ -67,24 +70,15 @@ var testCases = []testCase{
|
||||
name: "To duplicate internal route with custom dialer via physical interface",
|
||||
destination: "10.0.0.2:53",
|
||||
expectedDestPrefix: "10.0.0.2/32",
|
||||
expectedInterface: expectedExtInt,
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
},
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
destination: "10.0.0.2:53",
|
||||
expectedSourceIP: "127.0.0.1",
|
||||
expectedDestPrefix: "10.0.0.0/8",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
|
||||
{
|
||||
name: "To unique vpn route with custom dialer via physical interface",
|
||||
destination: "172.16.0.2:53",
|
||||
expectedDestPrefix: "172.16.0.2/32",
|
||||
expectedInterface: expectedExtInt,
|
||||
expectedInterface: expectedExternalInt,
|
||||
dialer: nbnet.NewDialer(),
|
||||
},
|
||||
{
|
||||
@ -93,7 +87,7 @@ var testCases = []testCase{
|
||||
expectedSourceIP: "100.64.0.1",
|
||||
expectedDestPrefix: "172.16.0.0/12",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "wgtest0",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
|
||||
@ -103,22 +97,14 @@ var testCases = []testCase{
|
||||
expectedSourceIP: "100.64.0.1",
|
||||
expectedDestPrefix: "10.10.0.0/24",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "wgtest0",
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
|
||||
{
|
||||
name: "To more specific route (local) without custom dialer via physical interface",
|
||||
destination: "127.0.10.2:53",
|
||||
expectedSourceIP: "127.0.0.1",
|
||||
expectedDestPrefix: "127.0.0.0/8",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||
expectedInterface: expectedVPNint,
|
||||
dialer: &net.Dialer{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupTestEnv(t)
|
||||
@ -129,7 +115,7 @@ func TestRouting(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to fetch interface IP")
|
||||
|
||||
output := testRoute(t, tc.destination, tc.dialer)
|
||||
if tc.expectedInterface == expectedExtInt {
|
||||
if tc.expectedInterface == expectedExternalInt {
|
||||
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
||||
} else {
|
||||
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
||||
@ -242,19 +228,23 @@ func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
func addDummyRoute(t *testing.T, dstCIDR string) {
|
||||
t.Helper()
|
||||
|
||||
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
|
||||
|
||||
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
prefix, err := netip.ParsePrefix(dstCIDR)
|
||||
if err != nil {
|
||||
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
|
||||
t.FailNow()
|
||||
t.Fatalf("Failed to parse destination CIDR %s: %v", dstCIDR, err)
|
||||
}
|
||||
|
||||
nexthop := Nexthop{
|
||||
Intf: &net.Interface{Index: 1},
|
||||
}
|
||||
|
||||
if err = addRoute(prefix, nexthop); err != nil {
|
||||
t.Fatalf("Failed to add dummy route: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
|
||||
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
err := deleteRoute(prefix, nexthop)
|
||||
if err != nil {
|
||||
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
|
||||
t.Logf("Failed to remove dummy route: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -3,11 +3,11 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
@ -19,81 +19,32 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine not initialized")
|
||||
tracer, engine, err := s.getPacketTracer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, fmt.Errorf("firewall manager not initialized")
|
||||
srcAddr, err := s.parseAddress(req.GetSourceIp(), engine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid source IP address: %w", err)
|
||||
}
|
||||
|
||||
tracer, ok := fwManager.(packetTracer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||
dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid destination IP address: %w", err)
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP(req.GetSourceIp())
|
||||
if req.GetSourceIp() == "self" {
|
||||
srcIP = engine.GetWgAddr()
|
||||
protocol, err := s.parseProtocol(req.GetProtocol())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
direction, err := s.parseDirection(req.GetDirection())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
||||
if req.GetDestinationIp() == "self" {
|
||||
dstIP = engine.GetWgAddr()
|
||||
}
|
||||
|
||||
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source IP address")
|
||||
}
|
||||
|
||||
if srcIP == nil || dstIP == nil {
|
||||
return nil, fmt.Errorf("invalid IP address")
|
||||
}
|
||||
|
||||
var tcpState *uspfilter.TCPState
|
||||
if flags := req.GetTcpFlags(); flags != nil {
|
||||
tcpState = &uspfilter.TCPState{
|
||||
SYN: flags.GetSyn(),
|
||||
ACK: flags.GetAck(),
|
||||
FIN: flags.GetFin(),
|
||||
RST: flags.GetRst(),
|
||||
PSH: flags.GetPsh(),
|
||||
URG: flags.GetUrg(),
|
||||
}
|
||||
}
|
||||
|
||||
var dir fw.RuleDirection
|
||||
switch req.GetDirection() {
|
||||
case "in":
|
||||
dir = fw.RuleDirectionIN
|
||||
case "out":
|
||||
dir = fw.RuleDirectionOUT
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid direction")
|
||||
}
|
||||
|
||||
var protocol fw.Protocol
|
||||
switch req.GetProtocol() {
|
||||
case "tcp":
|
||||
protocol = fw.ProtocolTCP
|
||||
case "udp":
|
||||
protocol = fw.ProtocolUDP
|
||||
case "icmp":
|
||||
protocol = fw.ProtocolICMP
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid protocolcol")
|
||||
}
|
||||
tcpState := s.parseTCPFlags(req.GetTcpFlags())
|
||||
|
||||
builder := &uspfilter.PacketBuilder{
|
||||
SrcIP: srcAddr,
|
||||
@ -101,16 +52,96 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
Protocol: protocol,
|
||||
SrcPort: uint16(req.GetSourcePort()),
|
||||
DstPort: uint16(req.GetDestinationPort()),
|
||||
Direction: dir,
|
||||
Direction: direction,
|
||||
TCPState: tcpState,
|
||||
ICMPType: uint8(req.GetIcmpType()),
|
||||
ICMPCode: uint8(req.GetIcmpCode()),
|
||||
}
|
||||
|
||||
trace, err := tracer.TracePacketFromBuilder(builder)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("trace packet: %w", err)
|
||||
}
|
||||
|
||||
return s.buildTraceResponse(trace), nil
|
||||
}
|
||||
|
||||
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, nil, fmt.Errorf("connect client not initialized")
|
||||
}
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, nil, fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
fwManager := engine.GetFirewallManager()
|
||||
if fwManager == nil {
|
||||
return nil, nil, fmt.Errorf("firewall manager not initialized")
|
||||
}
|
||||
|
||||
tracer, ok := fwManager.(packetTracer)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||
}
|
||||
|
||||
return tracer, engine, nil
|
||||
}
|
||||
|
||||
func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) {
|
||||
if addr == "self" {
|
||||
return engine.GetWgAddr(), nil
|
||||
}
|
||||
|
||||
a, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, err
|
||||
}
|
||||
|
||||
return a.Unmap(), nil
|
||||
}
|
||||
|
||||
func (s *Server) parseProtocol(protocol string) (fw.Protocol, error) {
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
return fw.ProtocolTCP, nil
|
||||
case "udp":
|
||||
return fw.ProtocolUDP, nil
|
||||
case "icmp":
|
||||
return fw.ProtocolICMP, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid protocol")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parseDirection(direction string) (fw.RuleDirection, error) {
|
||||
switch direction {
|
||||
case "in":
|
||||
return fw.RuleDirectionIN, nil
|
||||
case "out":
|
||||
return fw.RuleDirectionOUT, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid direction")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parseTCPFlags(flags *proto.TCPFlags) *uspfilter.TCPState {
|
||||
if flags == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &uspfilter.TCPState{
|
||||
SYN: flags.GetSyn(),
|
||||
ACK: flags.GetAck(),
|
||||
FIN: flags.GetFin(),
|
||||
RST: flags.GetRst(),
|
||||
PSH: flags.GetPsh(),
|
||||
URG: flags.GetUrg(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) buildTraceResponse(trace *uspfilter.PacketTrace) *proto.TracePacketResponse {
|
||||
resp := &proto.TracePacketResponse{}
|
||||
|
||||
for _, result := range trace.Results {
|
||||
@ -119,10 +150,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
Message: result.Message,
|
||||
Allowed: result.Allowed,
|
||||
}
|
||||
|
||||
if result.ForwarderAction != nil {
|
||||
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
||||
stage.ForwardingDetails = &details
|
||||
}
|
||||
|
||||
resp.Stages = append(resp.Stages, stage)
|
||||
}
|
||||
|
||||
@ -130,5 +163,5 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
||||
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return resp
|
||||
}
|
||||
|
@ -1,8 +1,10 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -54,11 +56,13 @@ func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
|
||||
// Calculate the last IP in the CIDR range
|
||||
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
addr := network.Addr().AsSlice()
|
||||
mask := net.CIDRMask(network.Bits(), len(addr)*8)
|
||||
|
||||
for i := 0; i < len(addr); i++ {
|
||||
endIP = append(endIP, addr[i]|^mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
@ -70,5 +74,10 @@ func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return resultInt.Bytes()
|
||||
ip, ok := netip.AddrFromSlice(resultInt.Bytes())
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
|
||||
}
|
||||
|
||||
return ip.Unmap(), nil
|
||||
}
|
||||
|
94
util/net/net_test.go
Normal file
94
util/net/net_test.go
Normal file
@ -0,0 +1,94 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
fromEnd int
|
||||
expected string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 /24 network - last IP (fromEnd=0)",
|
||||
network: "192.168.1.0/24",
|
||||
fromEnd: 0,
|
||||
expected: "192.168.1.255",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /24 network - fromEnd=1",
|
||||
network: "192.168.1.0/24",
|
||||
fromEnd: 1,
|
||||
expected: "192.168.1.254",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /24 network - fromEnd=5",
|
||||
network: "192.168.1.0/24",
|
||||
fromEnd: 5,
|
||||
expected: "192.168.1.250",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /16 network - last IP",
|
||||
network: "10.0.0.0/16",
|
||||
fromEnd: 0,
|
||||
expected: "10.0.255.255",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /16 network - fromEnd=256",
|
||||
network: "10.0.0.0/16",
|
||||
fromEnd: 256,
|
||||
expected: "10.0.254.255",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /32 network - single host",
|
||||
network: "192.168.1.100/32",
|
||||
fromEnd: 0,
|
||||
expected: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "IPv6 /64 network - last IP",
|
||||
network: "2001:db8::/64",
|
||||
fromEnd: 0,
|
||||
expected: "2001:db8::ffff:ffff:ffff:ffff",
|
||||
},
|
||||
{
|
||||
name: "IPv6 /64 network - fromEnd=1",
|
||||
network: "2001:db8::/64",
|
||||
fromEnd: 1,
|
||||
expected: "2001:db8::ffff:ffff:ffff:fffe",
|
||||
},
|
||||
{
|
||||
name: "IPv6 /128 network - single host",
|
||||
network: "2001:db8::1/128",
|
||||
fromEnd: 0,
|
||||
expected: "2001:db8::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
network, err := netip.ParsePrefix(tt.network)
|
||||
require.NoError(t, err, "Failed to parse network prefix")
|
||||
|
||||
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedIP, err := netip.ParseAddr(tt.expected)
|
||||
require.NoError(t, err, "Failed to parse expected IP")
|
||||
|
||||
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user