From ea4d13e96d79665bc0e4b489af5a0d492e104515 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 4 Jun 2025 16:28:58 +0200 Subject: [PATCH] [client] Use platform-native routing APIs for freeBSD, macOS and Windows --- client/cmd/trace.go | 2 +- .../firewall/iptables/manager_linux_test.go | 39 +- .../firewall/nftables/manager_linux_test.go | 31 +- .../firewall/uspfilter/forwarder/forwarder.go | 12 +- client/firewall/uspfilter/localip.go | 48 +- client/firewall/uspfilter/localip_test.go | 49 +- client/firewall/uspfilter/tracer_test.go | 7 +- client/firewall/uspfilter/uspfilter.go | 6 - .../uspfilter/uspfilter_bench_test.go | 74 +- .../uspfilter/uspfilter_filter_test.go | 22 +- client/firewall/uspfilter/uspfilter_test.go | 20 +- client/iface/bind/udp_mux_universal.go | 2 +- client/iface/device/device_filter.go | 4 - client/iface/device/device_netstack.go | 6 +- client/iface/device/wg_link_freebsd.go | 10 +- client/iface/iface.go | 1 - client/iface/mocks/filter.go | 13 - client/iface/netstack/tun.go | 22 +- client/iface/wgaddr/address.go | 15 +- client/internal/acl/manager_test.go | 17 +- client/internal/dns.go | 34 +- client/internal/dns/server_test.go | 12 +- client/internal/dns/service_memory.go | 8 +- client/internal/dns/service_memory_test.go | 33 - client/internal/dns/upstream_android.go | 5 +- client/internal/dns/upstream_general.go | 6 +- client/internal/dns/upstream_ios.go | 20 +- client/internal/dns/upstream_test.go | 4 +- client/internal/engine.go | 17 +- client/internal/engine_test.go | 7 +- .../internal/netflow/conntrack/conntrack.go | 12 +- client/internal/netflow/logger/logger.go | 13 +- client/internal/netflow/logger/logger_test.go | 4 +- client/internal/netflow/manager.go | 8 +- client/internal/netflow/manager_test.go | 12 +- .../routemanager/dnsinterceptor/handler.go | 2 +- client/internal/routemanager/manager.go | 9 +- client/internal/routemanager/manager_test.go | 30 +- .../routemanager/sysctl/sysctl_linux.go | 9 +- .../routemanager/systemops/systemops.go | 30 +- .../systemops/systemops_bsd_test.go | 78 ++- .../systemops/systemops_generic.go | 121 +--- .../systemops/systemops_generic_test.go | 635 ++++++++++-------- .../routemanager/systemops/systemops_linux.go | 14 +- .../systemops/systemops_linux_test.go | 7 - .../systemops/systemops_nonlinux.go | 6 + .../routemanager/systemops/systemops_test.go | 268 ++++++++ .../routemanager/systemops/systemops_unix.go | 204 ++++-- .../systemops/systemops_windows.go | 244 +++++-- .../systemops/systemops_windows_test.go | 64 +- client/server/trace.go | 169 +++-- util/net/net.go | 19 +- util/net/net_test.go | 94 +++ 53 files changed, 1552 insertions(+), 1046 deletions(-) delete mode 100644 client/internal/dns/service_memory_test.go create mode 100644 client/internal/routemanager/systemops/systemops_test.go create mode 100644 util/net/net_test.go diff --git a/client/cmd/trace.go b/client/cmd/trace.go index b2ff1f1b5..abb73b646 100644 --- a/client/cmd/trace.go +++ b/client/cmd/trace.go @@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{ Example: ` netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 - netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 + netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0 netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, Args: cobra.ExactArgs(3), RunE: tracePacket, diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index af9f5dd23..30f391a6d 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -2,7 +2,7 @@ package iptables import ( "fmt" - "net" + "net/netip" "testing" "time" @@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{ }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + IP: netip.MustParseAddr("10.20.0.1"), + Network: netip.MustParsePrefix("10.20.0.0/24"), } }, } @@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) { var rule2 []fw.Rule t.Run("add second rule", func(t *testing.T) { - ip := net.ParseIP("10.20.0.3") + ip := netip.MustParseAddr("10.20.0.3") port := &fw.Port{ IsRange: true, Values: []uint16{8043, 8046}, } - rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") + rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "") require.NoError(t, err, "failed to add rule") for _, r := range rule2 { @@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) { t.Run("reset check", func(t *testing.T) { // add second rule - ip := net.ParseIP("10.20.0.3") + ip := netip.MustParseAddr("10.20.0.3") port := &fw.Port{Values: []uint16{5353}} - _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "") require.NoError(t, err, "failed to add rule") err = manager.Close(nil) @@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) { }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + IP: netip.MustParseAddr("10.20.0.1"), + Network: netip.MustParsePrefix("10.20.0.0/24"), } }, } @@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) { var rule2 []fw.Rule t.Run("add second rule", func(t *testing.T) { - ip := net.ParseIP("10.20.0.3") + ip := netip.MustParseAddr("10.20.0.3") port := &fw.Port{ Values: []uint16{443}, } - rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") + rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default") for _, r := range rule2 { require.NoError(t, err, "failed to add rule") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") @@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) { }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + IP: netip.MustParseAddr("10.20.0.1"), + Network: netip.MustParsePrefix("10.20.0.0/24"), } }, } @@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) { require.NoError(t, err) - ip := net.ParseIP("10.20.0.100") + ip := netip.MustParseAddr("10.20.0.100") start := time.Now() for i := 0; i < testMax; i++ { port := &fw.Port{Values: []uint16{uint16(1000 + i)}} - _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "") require.NoError(t, err, "failed to add rule") } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 602a6b8dc..1dd3e9183 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -3,7 +3,6 @@ package nftables import ( "bytes" "fmt" - "net" "net/netip" "os/exec" "testing" @@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{ }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("100.96.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("100.96.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), } }, } @@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) { time.Sleep(time.Second) }() - ip := net.ParseIP("100.96.0.1") + ip := netip.MustParseAddr("100.96.0.1").Unmap() testClient := &nftables.Conn{} - rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") + rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") require.NoError(t, err, "failed to add rule") err = manager.Flush() @@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) { } compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() expectedExprs2 := []expr.Any{ &expr.Payload{ DestRegister: 1, @@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) { &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, - Data: add.AsSlice(), + Data: ip.AsSlice(), }, &expr.Payload{ DestRegister: 1, @@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) { }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("100.96.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("100.96.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + IP: netip.MustParseAddr("100.96.0.1"), + Network: netip.MustParsePrefix("100.96.0.0/16"), } }, } @@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) { time.Sleep(time.Second) }() - ip := net.ParseIP("10.20.0.100") + ip := netip.MustParseAddr("10.20.0.100") start := time.Now() for i := 0; i < testMax; i++ { port := &fw.Port{Values: []uint16{uint16(1000 + i)}} - _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "") require.NoError(t, err, "failed to add rule") if i%100 == 0 { @@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { verifyIptablesOutput(t, stdout, stderr) }) - ip := net.ParseIP("100.96.0.1") - _, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + ip := netip.MustParseAddr("100.96.0.1") + _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") require.NoError(t, err, "failed to add peer filtering rule") _, err = manager.AddRouteFiltering( diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 2ae983f6e..42a3e0800 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -41,7 +41,7 @@ type Forwarder struct { udpForwarder *udpForwarder ctx context.Context cancel context.CancelFunc - ip net.IP + ip tcpip.Address netstack bool } @@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow return nil, fmt.Errorf("failed to create NIC: %v", err) } - ones, _ := iface.Address().Network.Mask.Size() protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), - PrefixLen: ones, + Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), + PrefixLen: iface.Address().Network.Bits(), }, } @@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow ctx: ctx, cancel: cancel, netstack: netstack, - ip: iface.Address().IP, + ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), } receiveWindow := defaultReceiveWindow @@ -167,7 +166,7 @@ func (f *Forwarder) Stop() { } func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { - if f.netstack && f.ip.Equal(addr.AsSlice()) { + if f.netstack && f.ip.Equal(addr) { return net.IPv4(127, 0, 0, 1) } return addr.AsSlice() @@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin } func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { - if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { return value.([]byte), true } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go index f093f3429..7f6b52c71 100644 --- a/client/firewall/uspfilter/localip.go +++ b/client/firewall/uspfilter/localip.go @@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) { m.ipv4Bitmap[high].bitmap[index] |= 1 << bit } -func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { - if ipv4 := ip.To4(); ipv4 != nil { - high := uint16(ipv4[0]) - low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) +func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { + if !ip.Is4() { + return + } + ipv4 := ip.AsSlice() - if bitmap[high] == nil { - bitmap[high] = &ipv4LowBitmap{} - } + high := uint16(ipv4[0]) + low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) - index := low / 32 - bit := low % 32 - bitmap[high].bitmap[index] |= 1 << bit + if bitmap[high] == nil { + bitmap[high] = &ipv4LowBitmap{} + } - ipStr := ipv4.String() - if _, exists := ipv4Set[ipStr]; !exists { - ipv4Set[ipStr] = struct{}{} - *ipv4Addresses = append(*ipv4Addresses, ipStr) - } + index := low / 32 + bit := low % 32 + bitmap[high].bitmap[index] |= 1 << bit + + if _, exists := ipv4Set[ip]; !exists { + ipv4Set[ip] = struct{}{} + *ipv4Addresses = append(*ipv4Addresses, ip) } } @@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool { return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 } -func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { +func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error { m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) return nil } -func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { +func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) { addrs, err := iface.Addrs() if err != nil { log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) @@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv continue } - if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name) + continue + } + + if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil { log.Debugf("process IP failed: %v", err) } } @@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { }() var newIPv4Bitmap [256]*ipv4LowBitmap - ipv4Set := make(map[string]struct{}) - var ipv4Addresses []string + ipv4Set := make(map[netip.Addr]struct{}) + var ipv4Addresses []netip.Addr // 127.0.0.0/8 newIPv4Bitmap[127] = &ipv4LowBitmap{} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 0104c9603..45ac912cd 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) { { name: "Localhost range", setupAddr: wgaddr.Address{ - IP: net.ParseIP("192.168.1.1"), - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.0"), - Mask: net.CIDRMask(24, 32), - }, + IP: netip.MustParseAddr("192.168.1.1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("127.0.0.2"), expected: true, @@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) { { name: "Localhost standard address", setupAddr: wgaddr.Address{ - IP: net.ParseIP("192.168.1.1"), - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.0"), - Mask: net.CIDRMask(24, 32), - }, + IP: netip.MustParseAddr("192.168.1.1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("127.0.0.1"), expected: true, @@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) { { name: "Localhost range edge", setupAddr: wgaddr.Address{ - IP: net.ParseIP("192.168.1.1"), - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.0"), - Mask: net.CIDRMask(24, 32), - }, + IP: netip.MustParseAddr("192.168.1.1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("127.255.255.255"), expected: true, @@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) { { name: "Local IP matches", setupAddr: wgaddr.Address{ - IP: net.ParseIP("192.168.1.1"), - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.0"), - Mask: net.CIDRMask(24, 32), - }, + IP: netip.MustParseAddr("192.168.1.1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("192.168.1.1"), expected: true, @@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) { { name: "Local IP doesn't match", setupAddr: wgaddr.Address{ - IP: net.ParseIP("192.168.1.1"), - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.0"), - Mask: net.CIDRMask(24, 32), - }, + IP: netip.MustParseAddr("192.168.1.1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("192.168.1.2"), expected: false, @@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) { { name: "Local IP doesn't match - addresses 32 apart", setupAddr: wgaddr.Address{ - IP: net.ParseIP("192.168.1.1"), - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.0"), - Mask: net.CIDRMask(24, 32), - }, + IP: netip.MustParseAddr("192.168.1.1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("192.168.1.33"), expected: false, @@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) { { name: "IPv6 address", setupAddr: wgaddr.Address{ - IP: net.ParseIP("fe80::1"), - Network: &net.IPNet{ - IP: net.ParseIP("fe80::"), - Mask: net.CIDRMask(64, 128), - }, + IP: netip.MustParseAddr("fe80::1"), + Network: netip.MustParsePrefix("192.168.1.0/24"), }, testIP: netip.MustParseAddr("fe80::1"), expected: false, diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index bd87879a5..46c115787 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("100.10.0.100"), - Network: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), } }, } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 8e0a955ca..eede1ab13 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -71,7 +71,6 @@ type Manager struct { // incomingRules is used for filtering and hooks incomingRules map[netip.Addr]RuleSet routeRules RouteRules - wgNetwork *net.IPNet decoders sync.Pool wgIface common.IFaceMapper nativeFirewall firewall.Manager @@ -1091,11 +1090,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot return true } -// SetNetwork of the wireguard interface to which filtering applied -func (m *Manager) SetNetwork(network *net.IPNet) { - m.wgNetwork = network -} - // AddUDPPacketHook calls hook when UDP packet from given direction matched // // Hook function returns flag which indicates should be the matched package dropped or not diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index beb5b9336..c03e60640 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } - // Apply scenario-specific setup sc.setupFunc(manager) @@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } - // Pre-populate connection table srcIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count) @@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } - srcIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0] outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) @@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolTCP, state: "new", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolTCP, state: "established", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolUDP, state: "new", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolUDP, state: "established", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - } b.Setenv("NB_DISABLE_CONNTRACK", "1") }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolTCP, state: "new", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("0.0.0.0"), - Mask: net.CIDRMask(0, 32), - } require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolTCP, state: "established", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("0.0.0.0"), - Mask: net.CIDRMask(0, 32), - } require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolTCP, state: "post_handshake", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("0.0.0.0"), - Mask: net.CIDRMask(0, 32), - } require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolUDP, state: "new", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("0.0.0.0"), - Mask: net.CIDRMask(0, 32), - } require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { proto: layers.IPProtocolUDP, state: "established", setupFunc: func(m *Manager) { - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("0.0.0.0"), - Mask: net.CIDRMask(0, 32), - } require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) }, genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { @@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.SetNetwork(&net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - }) - // Setup initial state based on scenario if sc.rules { // Single rule to allow all return traffic from port 80 @@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.SetNetwork(&net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - }) - // Setup initial state based on scenario if sc.rules { // Single rule to allow all return traffic from port 80 @@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.SetNetwork(&net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - }) - // Setup initial state based on scenario if sc.rules { _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") @@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { require.NoError(b, manager.Close(nil)) }) - manager.SetNetwork(&net.IPNet{ - IP: net.ParseIP("100.64.0.0"), - Mask: net.CIDRMask(10, 32), - }) - if sc.rules { _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") require.NoError(b, err) @@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) { } for _, r := range rules { - _, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept) + dst := fw.Network{Prefix: r.dest} + _, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept) if err != nil { b.Fatal(err) } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index 04a398d1f..318f86a87 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -19,12 +19,8 @@ import ( ) func TestPeerACLFiltering(t *testing.T) { - localIP := net.ParseIP("100.10.0.100") - wgNet := &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - } - + localIP := netip.MustParseAddr("100.10.0.100") + wgNet := netip.MustParsePrefix("100.10.0.0/16") ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, AddressFunc: func() wgaddr.Address { @@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) { require.NoError(t, manager.Close(nil)) }) - manager.wgNetwork = wgNet - err = manager.UpdateLocalIPs() require.NoError(t, err) @@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { dev := mocks.NewMockDevice(ctrl) dev.EXPECT().MTU().Return(1500, nil).AnyTimes() - localIP, wgNet, err := net.ParseCIDR(network) - require.NoError(tb, err) + wgNet := netip.MustParsePrefix(network) ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: localIP, + IP: wgNet.Addr(), Network: wgNet, } }, @@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("100.10.0.100"), - Network: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), } }, } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 24a6a2c40..88de1ddcd 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("100.10.0.100"), - Network: &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - }, + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), } }, } @@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to create Manager: %v", err) return } - m.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - } ip := net.ParseIP("0.0.0.0") proto := fw.ProtocolUDP @@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) { }, false, flowLogger) require.NoError(t, err) - manager.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - } manager.udpTracker.Close() manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) defer func() { @@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { }, false, flowLogger) require.NoError(t, err) - manager.wgNetwork = &net.IPNet{ - IP: net.ParseIP("100.10.0.0"), - Mask: net.CIDRMask(16, 32), - } - manager.udpTracker.Close() // Close the existing tracker manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.decoders = sync.Pool{ diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go index 9fed02bb7..5cc634955 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/bind/udp_mux_universal.go @@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error { return nil } - if u.address.Network.Contains(a.AsSlice()) { + if u.address.Network.Contains(a) { log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) } diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index c9b7e2448..5a1a0e96a 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -1,7 +1,6 @@ package device import ( - "net" "net/netip" "sync" @@ -24,9 +23,6 @@ type PacketFilter interface { // RemovePacketHook removes hook by ID RemovePacketHook(hookID string) error - - // SetNetwork of the wireguard interface to which filtering applied - SetNetwork(*net.IPNet) } // FilteredDevice to override Read or Write of packets diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index d3c92235e..d2f2c87a1 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) { log.Info("create nbnetstack tun interface") // TODO: get from service listener runtime IP - dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) + dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1) + if err != nil { + return nil, fmt.Errorf("last ip: %w", err) + } + log.Debugf("netstack using address: %s", t.address.IP) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) log.Debugf("netstack using dns address: %s", dnsAddr) diff --git a/client/iface/device/wg_link_freebsd.go b/client/iface/device/wg_link_freebsd.go index 9067790e4..1b06e0e15 100644 --- a/client/iface/device/wg_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error { } ip := address.IP.String() - mask := "0x" + address.Network.Mask.String() + + // Convert prefix length to hex netmask + prefixLen := address.Network.Bits() + if !address.IP.Is4() { + return fmt.Errorf("IPv6 not supported for interface assignment") + } + + maskBits := uint32(0xffffffff) << (32 - prefixLen) + mask := fmt.Sprintf("0x%08x", maskBits) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) diff --git a/client/iface/iface.go b/client/iface/iface.go index c78a252da..1f659af29 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error { } w.filter = filter - w.filter.SetNetwork(w.tun.WgAddress().Network) w.tun.FilteredDevice().SetFilter(filter) return nil diff --git a/client/iface/mocks/filter.go b/client/iface/mocks/filter.go index faac55d68..8cd2a1231 100644 --- a/client/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -5,7 +5,6 @@ package mocks import ( - net "net" "net/netip" reflect "reflect" @@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) } - -// SetNetwork mocks base method. -func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetNetwork", arg0) -} - -// SetNetwork indicates an expected call of SetNetwork. -func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0) -} diff --git a/client/iface/netstack/tun.go b/client/iface/netstack/tun.go index a271a1954..aec9d4faa 100644 --- a/client/iface/netstack/tun.go +++ b/client/iface/netstack/tun.go @@ -1,8 +1,6 @@ package netstack import ( - "fmt" - "net" "net/netip" "os" "strconv" @@ -15,8 +13,8 @@ import ( const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" type NetStackTun struct { //nolint:revive - address net.IP - dnsAddress net.IP + address netip.Addr + dnsAddress netip.Addr mtu int listenAddress string @@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive tundev tun.Device } -func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { +func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { return &NetStackTun{ address: address, dnsAddress: dnsAddress, @@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu } func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { - addr, ok := netip.AddrFromSlice(t.address) - if !ok { - return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address) - } - - dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress) - if !ok { - return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress) - } - nsTunDev, tunNet, err := netstack.CreateNetTUN( - []netip.Addr{addr.Unmap()}, - []netip.Addr{dnsAddr.Unmap()}, + []netip.Addr{t.address}, + []netip.Addr{t.dnsAddress}, t.mtu) if err != nil { return nil, nil, err diff --git a/client/iface/wgaddr/address.go b/client/iface/wgaddr/address.go index e5079258c..078f8be95 100644 --- a/client/iface/wgaddr/address.go +++ b/client/iface/wgaddr/address.go @@ -2,28 +2,27 @@ package wgaddr import ( "fmt" - "net" + "net/netip" ) // Address WireGuard parsed address type Address struct { - IP net.IP - Network *net.IPNet + IP netip.Addr + Network netip.Prefix } // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address func ParseWGAddress(address string) (Address, error) { - ip, network, err := net.ParseCIDR(address) + prefix, err := netip.ParsePrefix(address) if err != nil { return Address{}, err } return Address{ - IP: ip, - Network: network, + IP: prefix.Addr().Unmap(), + Network: prefix.Masked(), }, nil } func (addr Address) String() string { - maskSize, _ := addr.Network.Mask.Size() - return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) + return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits()) } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 532d70a24..16620033e 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,7 +1,7 @@ package acl import ( - "net" + "net/netip" "testing" "github.com/golang/mock/gomock" @@ -43,12 +43,11 @@ func TestDefaultManager(t *testing.T) { ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().SetFilter(gomock.Any()) - ip, network, err := net.ParseCIDR("172.0.0.1/32") - require.NoError(t, err) + network := netip.MustParsePrefix("172.0.0.1/32") ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Address().Return(wgaddr.Address{ - IP: ip, + IP: network.Addr(), Network: network, }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() @@ -162,12 +161,11 @@ func TestDefaultManagerStateless(t *testing.T) { ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().SetFilter(gomock.Any()) - ip, network, err := net.ParseCIDR("172.0.0.1/32") - require.NoError(t, err) + network := netip.MustParsePrefix("172.0.0.1/32") ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Address().Return(wgaddr.Address{ - IP: ip, + IP: network.Addr(), Network: network, }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() @@ -372,12 +370,11 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().SetFilter(gomock.Any()) - ip, network, err := net.ParseCIDR("172.0.0.1/32") - require.NoError(t, err) + network := netip.MustParsePrefix("172.0.0.1/32") ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Address().Return(wgaddr.Address{ - IP: ip, + IP: network.Addr(), Network: network, }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() diff --git a/client/internal/dns.go b/client/internal/dns.go index 8a73f50f2..5e604bec5 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -2,7 +2,7 @@ package internal import ( "fmt" - "net" + "net/netip" "slices" "strings" @@ -12,13 +12,14 @@ import ( nbdns "github.com/netbirdio/netbird/dns" ) -func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { - ip := net.ParseIP(aRecord.RData) - if ip == nil || ip.To4() == nil { +func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) { + ip, err := netip.ParseAddr(aRecord.RData) + if err != nil { + log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err) return nbdns.SimpleRecord{}, false } - if !ipNet.Contains(ip) { + if !prefix.Contains(ip) { return nbdns.SimpleRecord{}, false } @@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple } // generateReverseZoneName creates the reverse DNS zone name for a given network -func generateReverseZoneName(ipNet *net.IPNet) (string, error) { - networkIP := ipNet.IP.Mask(ipNet.Mask) - maskOnes, _ := ipNet.Mask.Size() +func generateReverseZoneName(network netip.Prefix) (string, error) { + networkIP := network.Masked().Addr() + + if !networkIP.Is4() { + return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP) + } // round up to nearest byte - octetsToUse := (maskOnes + 7) / 8 + octetsToUse := (network.Bits() + 7) / 8 octets := strings.Split(networkIP.String(), ".") if octetsToUse > len(octets) { - return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) + return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits()) } reverseOctets := make([]string, octetsToUse) @@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool { } // collectPTRRecords gathers all PTR records for the given network from A records -func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { +func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, zone := range config.CustomZones { @@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec continue } - if ptrRecord, ok := createPTRRecord(record, ipNet); ok { + if ptrRecord, ok := createPTRRecord(record, prefix); ok { records = append(records, ptrRecord) } } @@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec } // addReverseZone adds a reverse DNS zone to the configuration for the given network -func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { - zoneName, err := generateReverseZoneName(ipNet) +func addReverseZone(config *nbdns.Config, network netip.Prefix) { + zoneName, err := generateReverseZoneName(network) if err != nil { log.Warn(err) return @@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { return } - records := collectPTRRecords(config, ipNet) + records := collectPTRRecords(config, network) reverseZone := nbdns.CustomZone{ Domain: zoneName, diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 1c7c9b117..e55b27910 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string { } func (w *mocWGIface) Address() wgaddr.Address { - ip, network, _ := net.ParseCIDR("100.66.100.0/24") return wgaddr.Address{ - IP: ip, - Network: network, + IP: netip.MustParseAddr("100.66.100.1"), + Network: netip.MustParsePrefix("100.66.100.0/24"), } } @@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - _, ipNet, err := net.ParseCIDR("100.66.100.1/32") - if err != nil { - t.Errorf("parse CIDR: %v", err) - return - } - packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any()) - packetfilter.EXPECT().SetNetwork(ipNet) if err := wgIface.SetFilter(packetfilter); err != nil { t.Errorf("set packet filter: %v", err) diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 34c563757..226202cf7 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -24,11 +24,15 @@ type ServiceViaMemory struct { } func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { + lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1) + if err != nil { + log.Errorf("get last ip from network: %v", err) + } s := &ServiceViaMemory{ wgInterface: wgIface, dnsMux: dns.NewServeMux(), - runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(), + runtimeIP: lastIP.String(), runtimePort: defaultPort, } return s @@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) { } firstLayerDecoder := layers.LayerTypeIPv4 - if s.wgInterface.Address().Network.IP.To4() == nil { + if s.wgInterface.Address().IP.Is6() { firstLayerDecoder = layers.LayerTypeIPv6 } diff --git a/client/internal/dns/service_memory_test.go b/client/internal/dns/service_memory_test.go deleted file mode 100644 index 244adfaef..000000000 --- a/client/internal/dns/service_memory_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package dns - -import ( - "net" - "testing" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -func TestGetLastIPFromNetwork(t *testing.T) { - tests := []struct { - addr string - ip string - }{ - {"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"}, - {"192.168.0.0/30", "192.168.0.2"}, - {"192.168.0.0/16", "192.168.255.254"}, - {"192.168.0.0/24", "192.168.0.254"}, - } - - for _, tt := range tests { - _, ipnet, err := net.ParseCIDR(tt.addr) - if err != nil { - t.Errorf("Error parsing CIDR: %v", err) - return - } - - lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String() - if lastIP != tt.ip { - t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP) - } - } -} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 06ffcba11..52d2ba58b 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -3,6 +3,7 @@ package dns import ( "context" "net" + "net/netip" "syscall" "time" @@ -23,8 +24,8 @@ type upstreamResolver struct { func newUpstreamResolver( ctx context.Context, _ string, - _ net.IP, - _ *net.IPNet, + _ netip.Addr, + _ netip.Prefix, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, domain string, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 9bb5feab0..1bc06a7c1 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -4,7 +4,7 @@ package dns import ( "context" - "net" + "net/netip" "time" "github.com/miekg/dns" @@ -19,8 +19,8 @@ type upstreamResolver struct { func newUpstreamResolver( ctx context.Context, _ string, - _ net.IP, - _ *net.IPNet, + _ netip.Addr, + _ netip.Prefix, statusRecorder *peer.Status, _ *hostsDNSHolder, domain string, diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index ca5b31132..648cab176 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "net/netip" "syscall" "time" @@ -18,16 +19,16 @@ import ( type upstreamResolverIOS struct { *upstreamResolverBase - lIP net.IP - lNet *net.IPNet + lIP netip.Addr + lNet netip.Prefix interfaceName string } func newUpstreamResolver( ctx context.Context, interfaceName string, - ip net.IP, - net *net.IPNet, + ip netip.Addr, + net netip.Prefix, statusRecorder *peer.Status, _ *hostsDNSHolder, domain string, @@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } client.DialTimeout = timeout - upstreamIP := net.ParseIP(upstreamHost) - if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { + upstreamIP, err := netip.ParseAddr(upstreamHost) + if err != nil { + log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err) + } + if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { log.Debugf("using private client to query upstream: %s", upstream) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) if err != nil { @@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // This method is needed for iOS -func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { +func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { index, err := getInterfaceIndex(interfaceName) if err != nil { log.Debugf("unable to get interface index for %s: %s", interfaceName, err) @@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration dialer := &net.Dialer{ LocalAddr: &net.UDPAddr{ - IP: ip, + IP: ip.AsSlice(), Port: 0, // Let the OS pick a free port }, Timeout: dialTimeout, diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 13bc91a37..e440995d9 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -2,7 +2,7 @@ package dns import ( "context" - "net" + "net/netip" "strings" "testing" "time" @@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") + resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver.upstreamServers = testCase.InputServers resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { diff --git a/client/internal/engine.go b/client/internal/engine.go index d015c1d6c..0dec799bf 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1008,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { - log.Errorf("failed to update clientRoutes, err: %v", err) + log.Errorf("failed to update routes: %v", err) } if e.acl != nil { @@ -1104,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), - Network: prefix, + Network: prefix.Masked(), Domains: domain.FromPunycodeList(protoRoute.Domains), NetID: route.NetID(protoRoute.NetID), NetworkType: route.NetworkType(protoRoute.NetworkType), @@ -1138,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE return entries } -func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { +func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), @@ -1790,9 +1790,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { } // GetWgAddr returns the wireguard address -func (e *Engine) GetWgAddr() net.IP { +func (e *Engine) GetWgAddr() netip.Addr { if e.wgInterface == nil { - return nil + return netip.Addr{} } return e.wgInterface.Address().IP } @@ -1861,12 +1861,7 @@ func (e *Engine) Address() (netip.Addr, error) { return netip.Addr{}, errors.New("wireguard interface not initialized") } - addr := e.wgInterface.Address() - ip, ok := netip.AddrFromSlice(addr.IP) - if !ok { - return netip.Addr{}, errors.New("failed to convert address to netip.Addr") - } - return ip.Unmap(), nil + return e.wgInterface.Address().IP, nil } func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 422059bd8..82c1ba0e2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -371,11 +371,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, AddressFunc: func() wgaddr.Address { return wgaddr.Address{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, + IP: netip.MustParseAddr("10.20.0.1"), + Network: netip.MustParsePrefix("10.20.0.0/24"), } }, UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index f8440b913..d01adf135 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool { // fallback if mark rules are not in place wgnet := c.iface.Address().Network - return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice()) + return wgnet.Contains(srcIP) || wgnet.Contains(dstIP) } // mapRxPackets maps packet counts to RX based on flow direction @@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes // fallback if marks are not set wgaddr := c.iface.Address().IP wgnetwork := c.iface.Address().Network - src, dst := srcIP.AsSlice(), dstIP.AsSlice() - switch { - case wgaddr.Equal(src): + case wgaddr == srcIP: return nftypes.Egress - case wgaddr.Equal(dst): + case wgaddr == dstIP: return nftypes.Ingress - case wgnetwork.Contains(src): + case wgnetwork.Contains(srcIP): // netbird network -> resource network return nftypes.Ingress - case wgnetwork.Contains(dst): + case wgnetwork.Contains(dstIP): // resource network -> netbird network return nftypes.Egress } diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index a3bd091b6..e28fdf2f4 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -2,7 +2,7 @@ package logger import ( "context" - "net" + "net/netip" "sync" "sync/atomic" "time" @@ -23,17 +23,16 @@ type Logger struct { rcvChan atomic.Pointer[rcvChan] cancel context.CancelFunc statusRecorder *peer.Status - wgIfaceIPNet net.IPNet + wgIfaceNet netip.Prefix dnsCollection atomic.Bool exitNodeCollection atomic.Bool Store types.Store } -func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { - +func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger { return &Logger{ statusRecorder: statusRecorder, - wgIfaceIPNet: wgIfaceIPNet, + wgIfaceNet: wgIfaceIPNet, Store: store.NewMemoryStore(), } } @@ -89,11 +88,11 @@ func (l *Logger) startReceiver() { var isSrcExitNode bool var isDestExitNode bool - if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { + if !l.wgIfaceNet.Contains(event.SourceIP) { event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) } - if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { + if !l.wgIfaceNet.Contains(event.DestIP) { event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) } diff --git a/client/internal/netflow/logger/logger_test.go b/client/internal/netflow/logger/logger_test.go index 06e10c36c..1144544d8 100644 --- a/client/internal/netflow/logger/logger_test.go +++ b/client/internal/netflow/logger/logger_test.go @@ -1,7 +1,7 @@ package logger_test import ( - "net" + "net/netip" "testing" "time" @@ -12,7 +12,7 @@ import ( ) func TestStore(t *testing.T) { - logger := logger.New(nil, net.IPNet{}) + logger := logger.New(nil, netip.Prefix{}) logger.Enable() event := types.EventFields{ diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index bf80e5a9f..e3b188468 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "net" + "net/netip" "runtime" "sync" "time" @@ -34,11 +34,11 @@ type Manager struct { // NewManager creates a new netflow manager func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager { - var ipNet net.IPNet + var prefix netip.Prefix if iface != nil { - ipNet = *iface.Address().Network + prefix = iface.Address().Network } - flowLogger := logger.New(statusRecorder, ipNet) + flowLogger := logger.New(statusRecorder, prefix) var ct nftypes.ConnTracker if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { diff --git a/client/internal/netflow/manager_test.go b/client/internal/netflow/manager_test.go index bf7e05f8e..0b5eb3be6 100644 --- a/client/internal/netflow/manager_test.go +++ b/client/internal/netflow/manager_test.go @@ -1,7 +1,7 @@ package netflow import ( - "net" + "net/netip" "testing" "time" @@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool { func TestManager_Update(t *testing.T) { mockIFace := &mockIFaceMapper{ address: wgaddr.Address{ - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.1"), - Mask: net.CIDRMask(24, 32), - }, + Network: netip.MustParsePrefix("192.168.1.1/32"), }, isUserspaceBind: true, } @@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) { func TestManager_Update_TokenPreservation(t *testing.T) { mockIFace := &mockIFaceMapper{ address: wgaddr.Address{ - Network: &net.IPNet{ - IP: net.ParseIP("192.168.1.1"), - Mask: net.CIDRMask(24, 32), - }, + Network: netip.MustParsePrefix("192.168.1.1/32"), }, isUserspaceBind: true, } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 6d51c88c0..78d5e3b30 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -264,7 +264,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { continue } - prefix := netip.PrefixFrom(ip, ip.BitLen()) + prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen()) newPrefixes = append(newPrefixes, prefix) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index afb74c23e..8dbbb5f77 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -333,11 +333,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + var merr *multierror.Error if !m.disableClientRoutes { filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) if err := m.updateSystemRoutes(filteredClientRoutes); err != nil { - log.Errorf("Failed to update system routes: %v", err) + merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err)) } m.updateClientNetworks(updateSerial, filteredClientRoutes) @@ -346,14 +347,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.clientRoutes = newClientRoutesIDMap if m.serverRouter == nil { - return nil + return nberrors.FormatErrorOrNil(merr) } if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil { - return fmt.Errorf("update routes: %w", err) + merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err)) } - return nil + return nberrors.FormatErrorOrNil(merr) } // SetRouteChangeListener set RouteListener for route change Notifier diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 680bd813f..a46ae080e 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -44,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -71,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: localPeerKey, - Network: netip.MustParsePrefix("100.64.252.250/30"), + Network: netip.MustParsePrefix("100.64.252.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -99,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: localPeerKey, - Network: netip.MustParsePrefix("100.64.30.250/30"), + Network: netip.MustParsePrefix("100.64.30.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -127,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: localPeerKey, - Network: netip.MustParsePrefix("100.64.30.250/30"), + Network: netip.MustParsePrefix("100.64.30.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -211,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -233,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -250,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -272,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -282,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "b", NetID: "routeA", Peer: remotePeerKey2, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -299,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -327,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "a", NetID: "routeA", Peer: localPeerKey, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -356,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "l1", NetID: "routeA", Peer: localPeerKey, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -376,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) { ID: "r1", NetID: "routeA", Peer: remotePeerKey1, - Network: netip.MustParsePrefix("100.64.251.250/30"), + Network: netip.MustParsePrefix("100.64.251.248/30"), NetworkType: route.IPv4Network, Metric: 9999, Masquerade: false, @@ -440,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) + err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) require.NoError(t, err, "should update routes with init routes") } - _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) + err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index ea63f02fc..f96a57f37 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -13,7 +13,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/internal/routemanager/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) const ( @@ -22,8 +22,13 @@ const ( srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark" ) +type iface interface { + Address() wgaddr.Address + Name() string +} + // Setup configures sysctl settings for RP filtering and source validation. -func Setup(wgIface iface.WGIface) (map[string]int, error) { +func Setup(wgIface iface) (map[string]int, error) { keys := map[string]int{} var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index fd511fc20..261567dc3 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -6,9 +6,10 @@ import ( "net/netip" "sync" - "github.com/netbirdio/netbird/client/internal/routemanager/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" ) type Nexthop struct { @@ -30,11 +31,16 @@ func (n Nexthop) String() string { return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name) } +type wgIface interface { + Address() wgaddr.Address + Name() string +} + type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { refCounter *ExclusionCounter - wgInterface iface.WGIface + wgInterface wgIface // prefixes is tracking all the current added prefixes im memory // (this is used in iOS as all route updates require a full table update) //nolint @@ -45,9 +51,27 @@ type SysOps struct { notifier *notifier.Notifier } -func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps { +func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, } } + +func (r *SysOps) validateRoute(prefix netip.Prefix) error { + addr := prefix.Addr() + + switch { + case + !addr.IsValid(), + addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsMulticast(), + addr.IsUnspecified() && prefix.Bits() != 0, + r.wgInterface.Address().Network.Contains(addr): + return vars.ErrRouteNotAllowed + } + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index a83d7f1de..0d892c162 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -8,6 +8,8 @@ import ( "net/netip" "os/exec" "regexp" + "runtime" + "strings" "sync" "testing" @@ -33,7 +35,12 @@ func init() { func TestConcurrentRoutes(t *testing.T) { baseIP := netip.MustParseAddr("192.0.2.0") - intf := &net.Interface{Name: "lo0"} + + var intf *net.Interface + var nexthop Nexthop + + _, intf = setupDummyInterface(t) + nexthop = Nexthop{netip.Addr{}, intf} r := NewSysOps(nil, nil) @@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) { go func(ip netip.Addr) { defer wg.Done() prefix := netip.PrefixFrom(ip, 32) - if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { + if err := r.addToRouteTable(prefix, nexthop); err != nil { t.Errorf("Failed to add route for %s: %v", prefix, err) } }(baseIP) @@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) { go func(ip netip.Addr) { defer wg.Done() prefix := netip.PrefixFrom(ip, 32) - if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil { + if err := r.removeFromRouteTable(prefix, nexthop); err != nil { t.Errorf("Failed to remove route for %s: %v", prefix, err) } }(baseIP) @@ -119,18 +126,39 @@ func TestBits(t *testing.T) { func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { t.Helper() - err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() - require.NoError(t, err, "Failed to create loopback alias") + if runtime.GOOS == "darwin" { + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return intf + } + + prefix, err := netip.ParsePrefix(ipAddressCIDR) + require.NoError(t, err, "Failed to parse prefix") + + netIntf, err := net.InterfaceByName(intf) + require.NoError(t, err, "Failed to get interface by name") + + nexthop := Nexthop{netip.Addr{}, netIntf} + + r := NewSysOps(nil, nil) + err = r.addToRouteTable(prefix, nexthop) + require.NoError(t, err, "Failed to add route to table") t.Cleanup(func() { - err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() - assert.NoError(t, err, "Failed to remove loopback alias") + err := r.removeFromRouteTable(prefix, nexthop) + assert.NoError(t, err, "Failed to remove route from table") }) - return "lo0" + return intf } -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { +func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) { t.Helper() var originalNexthop net.IP @@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) { return net.ParseIP(matches[1]), nil } +// setupDummyInterface creates a dummy tun interface for FreeBSD route testing +func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) { + t.Helper() + + if runtime.GOOS == "darwin" { + return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"} + } + + output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput() + require.NoError(t, err, "Failed to create tun interface: %s", string(output)) + + tunName := strings.TrimSpace(string(output)) + + output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput() + require.NoError(t, err, "Failed to configure tun interface: %s", string(output)) + + intf, err := net.InterfaceByName(tunName) + require.NoError(t, err, "Failed to get interface by name") + + t.Cleanup(func() { + if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil { + t.Logf("Failed to destroy tun interface %s: %v", tunName, err) + } + }) + + return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf +} + func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy) otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) + addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy) } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index eaef01815..d223a27b2 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -17,7 +17,6 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" @@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } -// TODO: fix: for default our wg address now appears as the default gw -func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - addr := netip.IPv4Unspecified() - if prefix.Addr().Is6() { - addr = netip.IPv6Unspecified() - } - - nexthop, err := GetNextHop(addr) - if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - if !prefix.Contains(nexthop.IP) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32) - if nexthop.IP.Is6() { - gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128) - } - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - nexthop, err = GetNextHop(nexthop.IP) - if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP) - return r.addToRouteTable(gatewayPrefix, nexthop) -} - // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { - addr := prefix.Addr() - switch { - case addr.IsLoopback(), - addr.IsLinkLocalUnicast(), - addr.IsLinkLocalMulticast(), - addr.IsInterfaceLocalMulticast(), - addr.IsUnspecified(), - addr.IsMulticast(): +func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) { + if err := r.validateRoute(prefix); err != nil { + return Nexthop{}, err + } + addr := prefix.Addr() + if addr.IsUnspecified() { return Nexthop{}, vars.ErrRouteNotAllowed } @@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface Intf: nexthop.Intf, } - vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) - if !ok { - return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr") - } + vpnAddr := vpnIntf.Address().IP // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { @@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er return nil } - return r.addNonExistingRoute(prefix, intf) -} - -// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table -func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}) + return r.addToRouteTable(prefix, nextHop) } // genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, @@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) { log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { - if runtime.GOOS == "freebsd" { - return Nexthop{Intf: intf}, nil - } - if preferredSrc == nil { - return Nexthop{}, vars.ErrRouteNotFound + return Nexthop{Intf: intf}, nil } log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc) @@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { return addr.Unmap(), nil } -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := GetRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := GetRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} - // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { localRoutes, err := hasSeparateRouting() diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 5b7b13f97..2a57e6044 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -3,23 +3,25 @@ package systemops import ( - "bytes" "context" + "errors" "fmt" "net" "net/netip" - "os" + "os/exec" "runtime" + "strconv" "strings" + "syscall" "testing" "github.com/pion/transport/v3/stdnet" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" ) type dialer interface { @@ -27,105 +29,370 @@ type dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } -func TestAddRemoveRoutes(t *testing.T) { +func TestAddVPNRoute(t *testing.T) { testCases := []struct { - name string - prefix netip.Prefix - shouldRouteToWireguard bool - shouldBeRemoved bool + name string + prefix netip.Prefix + expectError bool }{ { - name: "Should Add And Remove Route 100.66.120.0/24", - prefix: netip.MustParsePrefix("100.66.120.0/24"), - shouldRouteToWireguard: true, - shouldBeRemoved: true, + name: "IPv4 - Private network route", + prefix: netip.MustParsePrefix("10.10.100.0/24"), }, { - name: "Should Not Add Or Remove Route 127.0.0.1/32", - prefix: netip.MustParsePrefix("127.0.0.1/32"), - shouldRouteToWireguard: false, - shouldBeRemoved: false, + name: "IPv4 Single host", + prefix: netip.MustParsePrefix("10.111.111.111/32"), + }, + { + name: "IPv4 RFC3927 test range", + prefix: netip.MustParsePrefix("198.51.100.0/24"), + }, + { + name: "IPv4 Default route", + prefix: netip.MustParsePrefix("0.0.0.0/0"), + }, + + { + name: "IPv6 Subnet", + prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"), + }, + { + name: "IPv6 Single host", + prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"), + }, + { + name: "IPv6 Default route", + prefix: netip.MustParsePrefix("::/0"), + }, + + // IPv4 addresses that should be rejected (matches validateRoute logic) + { + name: "IPv4 Loopback", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + expectError: true, + }, + { + name: "IPv4 Link-local unicast", + prefix: netip.MustParsePrefix("169.254.1.1/32"), + expectError: true, + }, + { + name: "IPv4 Link-local multicast", + prefix: netip.MustParsePrefix("224.0.0.251/32"), + expectError: true, + }, + { + name: "IPv4 Multicast", + prefix: netip.MustParsePrefix("239.255.255.250/32"), + expectError: true, + }, + { + name: "IPv4 Unspecified with prefix", + prefix: netip.MustParsePrefix("0.0.0.0/32"), + expectError: true, + }, + + // IPv6 addresses that should be rejected (matches validateRoute logic) + { + name: "IPv6 Loopback", + prefix: netip.MustParsePrefix("::1/128"), + expectError: true, + }, + { + name: "IPv6 Link-local unicast", + prefix: netip.MustParsePrefix("fe80::1/128"), + expectError: true, + }, + { + name: "IPv6 Link-local multicast", + prefix: netip.MustParsePrefix("ff02::1/128"), + expectError: true, + }, + { + name: "IPv6 Interface-local multicast", + prefix: netip.MustParsePrefix("ff01::1/128"), + expectError: true, + }, + { + name: "IPv6 Multicast", + prefix: netip.MustParsePrefix("ff00::1/128"), + expectError: true, + }, + { + name: "IPv6 Unspecified with prefix", + prefix: netip.MustParsePrefix("::/128"), + expectError: true, + }, + + { + name: "IPv4 WireGuard interface network overlap", + prefix: netip.MustParsePrefix("100.65.75.0/24"), + expectError: true, + }, + { + name: "IPv4 WireGuard interface network subnet", + prefix: netip.MustParsePrefix("100.65.75.0/32"), + expectError: true, }, } for n, testCase := range testCases { - // todo resolve test execution on freebsd - if runtime.GOOS == "freebsd" { - t.Skip("skipping ", testCase.name, " on freebsd") - } t.Run(testCase.name, func(t *testing.T) { t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - opts := iface.WGIFaceOpts{ - IFaceName: fmt.Sprintf("utun53%d", n), - Address: "100.65.75.2/24", - WGPrivKey: peerPrivateKey.String(), - MTU: iface.DefaultMTU, - TransportNet: newNet, - } - wgInterface, err := iface.NewWGIFace(opts) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") + wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - - _, _, err = r.SetupRouting(nil, nil) + _, _, err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) }) - index, err := net.InterfaceByName(wgInterface.Name()) - require.NoError(t, err, "InterfaceByName should not return err") - intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} + intf, err := net.InterfaceByName(wgInterface.Name()) + require.NoError(t, err) + // add the route err = r.AddVPNRoute(testCase.prefix, intf) - require.NoError(t, err, "genericAddVPNRoute should not return err") + if testCase.expectError { + assert.ErrorIs(t, err, vars.ErrRouteNotAllowed) + return + } - if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) + // validate it's pointing to the WireGuard interface + require.NoError(t, err) + + nextHop := getNextHop(t, testCase.prefix.Addr()) + assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface") + + // remove route again + err = r.RemoveVPNRoute(testCase.prefix, intf) + require.NoError(t, err) + + // validate it's gone + nextHop, err = GetNextHop(testCase.prefix.Addr()) + require.True(t, + errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(), + "err: %v, next hop: %v", err, nextHop) + }) + } +} + +func getNextHop(t *testing.T, addr netip.Addr) Nexthop { + t.Helper() + + if runtime.GOOS == "windows" || runtime.GOOS == "linux" { + nextHop, err := GetNextHop(addr) + + if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() { + // TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is + // present in the route table. + t.Skip("Skipping windows test") + } + + require.NoError(t, err) + require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr) + + return nextHop + } + // GetNextHop for bsd is buggy and returns the wrong interface for the default route. + + if addr.IsUnspecified() { + // On macOS, querying 0.0.0.0 returns the wrong interface + if addr.Is4() { + addr = netip.MustParseAddr("1.2.3.4") + } else { + addr = netip.MustParseAddr("2001:db8::1") + } + } + + cmd := exec.Command("route", "-n", "get", addr.String()) + if addr.Is6() { + cmd = exec.Command("route", "-n", "get", "-inet6", addr.String()) + } + + output, err := cmd.CombinedOutput() + t.Logf("route output: %s", output) + require.NoError(t, err, "%s failed") + + lines := strings.Split(string(output), "\n") + var intf string + var gateway string + + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "interface:") { + intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:")) + } else if strings.HasPrefix(line, "gateway:") { + gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:")) + } + } + + require.NotEmpty(t, intf, "interface should be found in route output") + + iface, err := net.InterfaceByName(intf) + require.NoError(t, err, "interface %s should exist", intf) + + nexthop := Nexthop{Intf: iface} + + if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) { + addr, err := netip.ParseAddr(gateway) + if err == nil { + nexthop.IP = addr + } + } + + return nexthop +} + +func TestAddRouteToNonVPNIntf(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + expectError bool + errorType error + }{ + { + name: "IPv4 RFC3927 test range", + prefix: netip.MustParsePrefix("198.51.100.0/24"), + }, + { + name: "IPv4 Single host", + prefix: netip.MustParsePrefix("8.8.8.8/32"), + }, + { + name: "IPv6 External network route", + prefix: netip.MustParsePrefix("2001:db8:1000::/48"), + }, + { + name: "IPv6 Single host", + prefix: netip.MustParsePrefix("2001:db8::1/128"), + }, + { + name: "IPv6 Subnet", + prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"), + }, + { + name: "IPv6 Single host", + prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"), + }, + + // Addresses that should be rejected + { + name: "IPv4 Loopback", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv4 Link-local unicast", + prefix: netip.MustParsePrefix("169.254.1.1/32"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv4 Multicast", + prefix: netip.MustParsePrefix("239.255.255.250/32"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv4 Unspecified", + prefix: netip.MustParsePrefix("0.0.0.0/0"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv6 Loopback", + prefix: netip.MustParsePrefix("::1/128"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv6 Link-local unicast", + prefix: netip.MustParsePrefix("fe80::1/128"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv6 Multicast", + prefix: netip.MustParsePrefix("ff00::1/128"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv6 Unspecified", + prefix: netip.MustParsePrefix("::/0"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + { + name: "IPv4 WireGuard interface network overlap", + prefix: netip.MustParsePrefix("100.65.75.0/24"), + expectError: true, + errorType: vars.ErrRouteNotAllowed, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") + + wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) + + r := NewSysOps(wgInterface, nil) + _, _, err := r.SetupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, r.CleanupRouting(nil)) + }) + + initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) + require.NoError(t, err, "Should be able to get IPv4 default route") + t.Logf("Initial IPv4 next hop: %s", initialNextHopV4) + + initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified()) + if testCase.prefix.Addr().Is6() && + (errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) { + t.Skip("Skipping test as no ipv6 default route is available") + } + if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { + t.Fatalf("Failed to get IPv6 default route: %v", err) + } + + var initialNextHop Nexthop + if testCase.prefix.Addr().Is6() { + initialNextHop = initialNextHopV6 } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) + initialNextHop = initialNextHopV4 } - exists, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "existsInRouteTable should not return err") - if exists && testCase.shouldRouteToWireguard { - err = r.RemoveVPNRoute(testCase.prefix, intf) - require.NoError(t, err, "genericRemoveVPNRoute should not return err") - prefixNexthop, err := GetNextHop(testCase.prefix.Addr()) - require.NoError(t, err, "GetNextHop should not return err") + nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop) - internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) - require.NoError(t, err) - - if testCase.shouldBeRemoved { - require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway") - } else { - require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway") - } + if testCase.expectError { + require.ErrorIs(t, err, vars.ErrRouteNotAllowed) + return } + require.NoError(t, err) + t.Logf("Next hop for %s: %s", testCase.prefix, nexthop) + + // Verify the route was added and points to non-VPN interface + currentNextHop, err := GetNextHop(testCase.prefix.Addr()) + require.NoError(t, err) + assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface") + + err = r.removeFromRouteTable(testCase.prefix, nexthop) + assert.NoError(t, err) }) } } func TestGetNextHop(t *testing.T) { - if runtime.GOOS == "freebsd" { - t.Skip("skipping on freebsd") - } - nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) + defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if !nexthop.IP.IsValid() { + if !defaultNh.IP.IsValid() { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) { t.Fatal("shouldn't return error when fetching interface addresses: ", err) } - var testingIP string var testingPrefix netip.Prefix for _, address := range addresses { if address.Network() != "ip+net" { @@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) { } prefix := netip.MustParsePrefix(address.String()) if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { - testingIP = prefix.Addr().String() testingPrefix = prefix.Masked() break } } - localIP, err := GetNextHop(testingPrefix.Addr()) + nh, err := GetNextHop(testingPrefix.Addr()) if err != nil { t.Fatal("shouldn't return error: ", err) } - if !localIP.IP.IsValid() { + if nh.Intf == nil { t.Fatal("should return a gateway for local network") } - if localIP.IP.String() == nexthop.IP.String() { - t.Fatal("local IP should not match with gateway IP") + if nh.IP.String() == defaultNh.IP.String() { + t.Fatal("next hop IP should not match with default gateway IP") } - if localIP.IP.String() != testingIP { - t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String()) - } -} - -func TestAddExistAndRemoveRoute(t *testing.T) { - defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) - t.Log("defaultNexthop: ", defaultNexthop) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - testCases := []struct { - name string - prefix netip.Prefix - preExistingPrefix netip.Prefix - shouldAddRoute bool - }{ - { - name: "Should Add And Remove random Route", - prefix: netip.MustParsePrefix("99.99.99.99/32"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if overlaps with default gateway", - prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"), - shouldAddRoute: false, - }, - { - name: "Should Add Route if bigger network exists", - prefix: netip.MustParsePrefix("100.100.100.0/24"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: true, - }, - { - name: "Should Add Route if smaller network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if same network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: false, - }, - } - - for n, testCase := range testCases { - - var buf bytes.Buffer - log.SetOutput(&buf) - defer func() { - log.SetOutput(os.Stderr) - }() - t.Run(testCase.name, func(t *testing.T) { - t.Setenv("NB_USE_LEGACY_ROUTING", "true") - t.Setenv("NB_DISABLE_ROUTE_CACHE", "true") - - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - opts := iface.WGIFaceOpts{ - IFaceName: fmt.Sprintf("utun53%d", n), - Address: "100.65.75.2/24", - WGPort: 33100, - WGPrivKey: peerPrivateKey.String(), - MTU: iface.DefaultMTU, - TransportNet: newNet, - } - wgInterface, err := iface.NewWGIFace(opts) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - - index, err := net.InterfaceByName(wgInterface.Name()) - require.NoError(t, err, "InterfaceByName should not return err") - intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} - - r := NewSysOps(wgInterface, nil) - - // Prepare the environment - if testCase.preExistingPrefix.IsValid() { - err := r.AddVPNRoute(testCase.preExistingPrefix, intf) - require.NoError(t, err, "should not return err when adding pre-existing route") - } - - // Add the route - err = r.AddVPNRoute(testCase.prefix, intf) - require.NoError(t, err, "should not return err when adding route") - - if testCase.shouldAddRoute { - // test if route exists after adding - ok, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "should not return err") - require.True(t, ok, "route should exist") - - // remove route again if added - err = r.RemoveVPNRoute(testCase.prefix, intf) - require.NoError(t, err, "should not return err") - } - - // route should either not have been added or should have been removed - // In case of already existing route, it should not have been added (but still exist) - ok, err := existsInRouteTable(testCase.prefix) - t.Log("Buffer string: ", buf.String()) - require.NoError(t, err, "should not return err") - - if !strings.Contains(buf.String(), "because it already exists") { - require.False(t, ok, "route should not exist") - } - }) - } -} - -func TestIsSubRange(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var subRangeAddressPrefixes []netip.Prefix - var nonSubRangeAddressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { - p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) - subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) - nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) - } - } - - for _, prefix := range subRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if !isSubRangePrefix { - t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) - } - } - - for _, prefix := range nonSubRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if isSubRangePrefix { - t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) - } - } -} - -func TestExistsInRouteTable(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - - switch { - case p.Addr().Is6(): - continue - // Windows sometimes has hidden interface link local addrs that don't turn up on any interface - case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast(): - continue - // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence - case runtime.GOOS == "linux" && p.Addr().IsLoopback(): - continue - // FreeBSD loopback 127/8 is not added to the routing table - case runtime.GOOS == "freebsd" && p.Addr().IsLoopback(): - continue - default: - addressPrefixes = append(addressPrefixes, p.Masked()) - } - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } + if nh.Intf.Name != defaultNh.Intf.Name { + t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name) } } @@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) { t.Helper() - err := r.AddVPNRoute(prefix, intf) - require.NoError(t, err, "addVPNRoute should not return err") + if err := r.AddVPNRoute(prefix, intf); err != nil { + if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) { + t.Fatalf("addVPNRoute should not return err: %v", err) + } + t.Logf("addVPNRoute %v returned: %v", prefix, err) + } t.Cleanup(func() { - err = r.RemoveVPNRoute(prefix, intf) - assert.NoError(t, err, "removeVPNRoute should not return err") + if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) { + t.Fatalf("removeVPNRoute should not return err: %v", err) + } }) } @@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) { // 10.10.0.0/24 more specific route exists in vpn table setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf) - // 127.0.10.0/24 more specific route exists in vpn table - setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf) - // unique route in vpn table setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf) } -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { - return - } - - prefixNexthop, err := GetNextHop(prefix.Addr()) - require.NoError(t, err, "GetNextHop should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP") - } -} - func TestIsVpnRoute(t *testing.T) { tests := []struct { name string diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 59b6346c6..b48cfa242 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro } func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if err := r.validateRoute(prefix); err != nil { + return err + } + if !nbnet.AdvancedRouting() { return r.genericAddVPNRoute(prefix, intf) } @@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { } func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if err := r.validateRoute(prefix); err != nil { + return err + } + if !nbnet.AdvancedRouting() { return r.genericRemoveVPNRoute(prefix, intf) } @@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) { ones, _ := route.Dst.Mask.Size() - prefix := netip.PrefixFrom(addr, ones) + prefix := netip.PrefixFrom(addr.Unmap(), ones) if prefix.IsValid() { prefixList = append(prefixList, prefix) } @@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { + if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) { return fmt.Errorf("netlink add route: %w", err) } @@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error { Dst: ipNet, } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { + if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) { return fmt.Errorf("netlink add unreachable route: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_linux_test.go b/client/internal/routemanager/systemops/systemops_linux_test.go index f0d7472dc..880296d91 100644 --- a/client/internal/routemanager/systemops/systemops_linux_test.go +++ b/client/internal/routemanager/systemops/systemops_linux_test.go @@ -19,7 +19,6 @@ import ( ) var expectedVPNint = "wgtest0" -var expectedLoopbackInt = "lo" var expectedExternalInt = "dummyext0" var expectedInternalInt = "dummyint0" @@ -31,12 +30,6 @@ func init() { dialer: &net.Dialer{}, expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), }, - { - name: "To more specific route (local) without custom dialer via physical interface", - expectedInterface: expectedLoopbackInt, - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, }...) } diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 3b52fc7af..59581255f 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -11,10 +11,16 @@ import ( ) func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if err := r.validateRoute(prefix); err != nil { + return err + } return r.genericAddVPNRoute(prefix, intf) } func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + if err := r.validateRoute(prefix); err != nil { + return err + } return r.genericRemoveVPNRoute(prefix, intf) } diff --git a/client/internal/routemanager/systemops/systemops_test.go b/client/internal/routemanager/systemops/systemops_test.go new file mode 100644 index 000000000..1d1f78830 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_test.go @@ -0,0 +1,268 @@ +package systemops + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/routemanager/notifier" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" +) + +type mockWGIface struct { + address wgaddr.Address + name string +} + +func (m *mockWGIface) Address() wgaddr.Address { + return m.address +} + +func (m *mockWGIface) Name() string { + return m.name +} + +func TestSysOps_validateRoute(t *testing.T) { + wgNetwork := netip.MustParsePrefix("10.0.0.0/24") + mockWG := &mockWGIface{ + address: wgaddr.Address{ + IP: wgNetwork.Addr(), + Network: wgNetwork, + }, + name: "wg0", + } + + sysOps := &SysOps{ + wgInterface: mockWG, + notifier: ¬ifier.Notifier{}, + } + + tests := []struct { + name string + prefix string + expectError bool + }{ + // Valid routes + { + name: "valid IPv4 route", + prefix: "192.168.1.0/24", + expectError: false, + }, + { + name: "valid IPv6 route", + prefix: "2001:db8::/32", + expectError: false, + }, + { + name: "valid single IPv4 host", + prefix: "8.8.8.8/32", + expectError: false, + }, + { + name: "valid single IPv6 host", + prefix: "2001:4860:4860::8888/128", + expectError: false, + }, + + // Invalid routes - loopback + { + name: "IPv4 loopback", + prefix: "127.0.0.1/32", + expectError: true, + }, + { + name: "IPv6 loopback", + prefix: "::1/128", + expectError: true, + }, + + // Invalid routes - link-local unicast + { + name: "IPv4 link-local unicast", + prefix: "169.254.1.1/32", + expectError: true, + }, + { + name: "IPv6 link-local unicast", + prefix: "fe80::1/128", + expectError: true, + }, + + // Invalid routes - multicast + { + name: "IPv4 multicast", + prefix: "224.0.0.1/32", + expectError: true, + }, + { + name: "IPv6 multicast", + prefix: "ff02::1/128", + expectError: true, + }, + + // Invalid routes - link-local multicast + { + name: "IPv4 link-local multicast", + prefix: "224.0.0.0/24", + expectError: true, + }, + { + name: "IPv6 link-local multicast", + prefix: "ff02::/16", + expectError: true, + }, + + // Invalid routes - interface-local multicast (IPv6 only) + { + name: "IPv6 interface-local multicast", + prefix: "ff01::1/128", + expectError: true, + }, + + // Invalid routes - overlaps with WG interface network + { + name: "overlaps with WG network - exact match", + prefix: "10.0.0.0/24", + expectError: true, + }, + { + name: "overlaps with WG network - subset", + prefix: "10.0.0.1/32", + expectError: true, + }, + { + name: "overlaps with WG network - host in range", + prefix: "10.0.0.100/32", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefix, err := netip.ParsePrefix(tt.prefix) + require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix) + + err = sysOps.validateRoute(prefix) + + if tt.expectError { + require.Error(t, err, "validateRoute() expected error for %s", tt.prefix) + assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix) + } else { + assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix) + } + }) + } +} + +func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) { + wgNetwork := netip.MustParsePrefix("192.168.100.0/24") + mockWG := &mockWGIface{ + address: wgaddr.Address{ + IP: wgNetwork.Addr(), + Network: wgNetwork, + }, + name: "wg0", + } + + sysOps := &SysOps{ + wgInterface: mockWG, + notifier: ¬ifier.Notifier{}, + } + + tests := []struct { + name string + prefix string + expectError bool + description string + }{ + { + name: "identical subnet", + prefix: "192.168.100.0/24", + expectError: true, + description: "exact same network as WG interface", + }, + { + name: "broader subnet containing WG network", + prefix: "192.168.0.0/16", + expectError: false, + description: "broader network that contains WG network should be allowed", + }, + { + name: "host within WG network", + prefix: "192.168.100.50/32", + expectError: true, + description: "specific host within WG network", + }, + { + name: "subnet within WG network", + prefix: "192.168.100.128/25", + expectError: true, + description: "smaller subnet within WG network", + }, + { + name: "adjacent subnet - same /23", + prefix: "192.168.101.0/24", + expectError: false, + description: "adjacent subnet, no overlap", + }, + { + name: "adjacent subnet - different /16", + prefix: "192.167.100.0/24", + expectError: false, + description: "different network, no overlap", + }, + { + name: "WG network broadcast address", + prefix: "192.168.100.255/32", + expectError: true, + description: "broadcast address of WG network", + }, + { + name: "WG network first usable", + prefix: "192.168.100.1/32", + expectError: true, + description: "first usable address in WG network", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefix, err := netip.ParsePrefix(tt.prefix) + require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix) + + err = sysOps.validateRoute(prefix) + + if tt.expectError { + require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description) + assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description) + } else { + assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description) + } + }) + } +} + +func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) { + wgNetwork := netip.MustParsePrefix("10.0.0.0/24") + mockWG := &mockWGIface{ + address: wgaddr.Address{ + IP: wgNetwork.Addr(), + Network: wgNetwork, + }, + name: "wt0", + } + + sysOps := &SysOps{ + wgInterface: mockWG, + notifier: ¬ifier.Notifier{}, + } + + var invalidPrefix netip.Prefix + err := sysOps.validateRoute(invalidPrefix) + + require.Error(t, err, "validateRoute() expected error for invalid prefix") + assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix") +} diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 0f8f2a341..f284e131b 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -3,15 +3,19 @@ package systemops import ( + "errors" "fmt" "net" "net/netip" - "os/exec" - "strings" + "strconv" + "syscall" "time" + "unsafe" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" @@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { - return r.routeCmd("add", prefix, nexthop) + return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { - return r.routeCmd("delete", prefix, nexthop) + return r.routeSocket(unix.RTM_DELETE, prefix, nexthop) } -func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error { - inet := "-inet" - if prefix.Addr().Is6() { - inet = "-inet6" - } - - network := prefix.String() - if prefix.IsSingleIP() { - network = prefix.Addr().String() - } - - args := []string{"-n", action, inet, network} - if nexthop.IP.IsValid() { - args = append(args, nexthop.IP.Unmap().String()) - } else if nexthop.Intf != nil { - args = append(args, "-interface", nexthop.Intf.Name) - } - - if err := retryRouteCmd(args); err != nil { - return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) - } - return nil -} - -func retryRouteCmd(args []string) error { - operation := func() error { - out, err := exec.Command("route", args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - // https://github.com/golang/go/issues/45736 - if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") { - return err - } else if err != nil { - return backoff.Permanent(err) - } - return nil +func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error { + if !prefix.IsValid() { + return fmt.Errorf("invalid prefix: %s", prefix) } expBackOff := backoff.NewExponentialBackOff() @@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error { expBackOff.MaxInterval = 500 * time.Millisecond expBackOff.MaxElapsedTime = 1 * time.Second - err := backoff.Retry(operation, expBackOff) - if err != nil { - return fmt.Errorf("route cmd retry failed: %w", err) + if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil { + a := "add" + if action == unix.RTM_DELETE { + a = "remove" + } + return fmt.Errorf("%s route for %s: %w", a, prefix, err) } return nil } + +func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error { + operation := func() error { + fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + if err != nil { + return fmt.Errorf("open routing socket: %w", err) + } + defer func() { + if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { + log.Warnf("failed to close routing socket: %v", err) + } + }() + + msg, err := r.buildRouteMessage(action, prefix, nexthop) + if err != nil { + return backoff.Permanent(fmt.Errorf("build route message: %w", err)) + } + + msgBytes, err := msg.Marshal() + if err != nil { + return backoff.Permanent(fmt.Errorf("marshal route message: %w", err)) + } + + if _, err = unix.Write(fd, msgBytes); err != nil { + if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { + return fmt.Errorf("write: %w", err) + } + return backoff.Permanent(fmt.Errorf("write: %w", err)) + } + + respBuf := make([]byte, 2048) + n, err := unix.Read(fd, respBuf) + if err != nil { + return backoff.Permanent(fmt.Errorf("read route response: %w", err)) + } + + if n > 0 { + if err := r.parseRouteResponse(respBuf[:n]); err != nil { + return backoff.Permanent(err) + } + } + + return nil + } + return operation +} + +func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { + msg = &route.RouteMessage{ + Type: action, + Flags: unix.RTF_UP, + Version: unix.RTM_VERSION, + Seq: 1, + } + + const numAddrs = unix.RTAX_NETMASK + 1 + addrs := make([]route.Addr, numAddrs) + + addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr()) + if err != nil { + return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err) + } + + if prefix.IsSingleIP() { + msg.Flags |= unix.RTF_HOST + } else { + addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix) + if err != nil { + return nil, fmt.Errorf("build netmask for %s: %w", prefix, err) + } + } + + if nexthop.IP.IsValid() { + msg.Flags |= unix.RTF_GATEWAY + addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap()) + if err != nil { + return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err) + } + } else if nexthop.Intf != nil { + msg.Index = nexthop.Intf.Index + addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{ + Index: nexthop.Intf.Index, + Name: nexthop.Intf.Name, + } + } + + msg.Addrs = addrs + return msg, nil +} + +func (r *SysOps) parseRouteResponse(buf []byte) error { + if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) { + return nil + } + + rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + if rtMsg.Errno != 0 { + return fmt.Errorf("parse: %d", rtMsg.Errno) + } + + return nil +} + +// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). +func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { + if addr.Is4() { + return &route.Inet4Addr{IP: addr.As4()}, nil + } + + if addr.Zone() == "" { + return &route.Inet6Addr{IP: addr.As16()}, nil + } + + var zone int + // zone can be either a numeric zone ID or an interface name. + if z, err := strconv.Atoi(addr.Zone()); err == nil { + zone = z + } else { + iface, err := net.InterfaceByName(addr.Zone()) + if err != nil { + return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err) + } + zone = iface.Index + } + return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil +} + +func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) { + bits := prefix.Bits() + if prefix.Addr().Is4() { + m := net.CIDRMask(bits, 32) + var maskBytes [4]byte + copy(maskBytes[:], m) + return &route.Inet4Addr{IP: maskBytes}, nil + } + + if prefix.Addr().Is6() { + m := net.CIDRMask(bits, 128) + var maskBytes [16]byte + copy(maskBytes[:], m) + return &route.Inet6Addr{IP: maskBytes}, nil + } + + return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String()) +} diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index f66161595..11eaa435e 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -1,5 +1,3 @@ -//go:build windows - package systemops import ( @@ -9,9 +7,8 @@ import ( "net" "net/netip" "os" - "os/exec" + "runtime/debug" "strconv" - "strings" "sync" "syscall" "time" @@ -21,11 +18,12 @@ import ( "github.com/yusufpapurcu/wmi" "golang.org/x/sys/windows" - "github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) +const InfiniteLifetime = 0xffffffff + type RouteUpdateType int // RouteUpdate represents a change in the routing table. @@ -58,9 +56,13 @@ type MSFT_NetRoute struct { AddressFamily uint16 } -// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2 +// luid represents a locally unique identifier for network interfaces +type luid uint64 + +// MIB_IPFORWARD_ROW2 represents a route entry in the routing table. +// It is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2 type MIB_IPFORWARD_ROW2 struct { - InterfaceLuid uint64 + InterfaceLuid luid InterfaceIndex uint32 DestinationPrefix IP_ADDRESS_PREFIX NextHop SOCKADDR_INET_NEXTHOP @@ -108,9 +110,14 @@ type SOCKADDR_INET_NEXTHOP struct { type MIB_NOTIFICATION_TYPE int32 var ( - modiphlpapi = windows.NewLazyDLL("iphlpapi.dll") - procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") - procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") + modiphlpapi = windows.NewLazyDLL("iphlpapi.dll") + procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") + procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") + procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2") + procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2") + procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2") + procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry") + procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid") prefixList []netip.Prefix lastUpdate time.Time @@ -139,6 +146,8 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + log.Debugf("Adding route to %s via %s", prefix, nexthop) + // if we don't have an interface but a zone, extract the interface index from the zone if nexthop.IP.Zone() != "" && nexthop.Intf == nil { zone, err := strconv.Atoi(nexthop.IP.Zone()) if err != nil { @@ -147,23 +156,187 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { nexthop.Intf = &net.Interface{Index: zone} } - return addRouteCmd(prefix, nexthop) + return addRoute(prefix, nexthop) } func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { - args := []string{"delete", prefix.String()} - if nexthop.IP.IsValid() { - ip := nexthop.IP.WithZone("") - args = append(args, ip.Unmap().String()) + log.Debugf("Removing route to %s via %s", prefix, nexthop) + return deleteRoute(prefix, nexthop) +} + +// setupRouteEntry prepares a route entry with common configuration +func setupRouteEntry(prefix netip.Prefix, nexthop Nexthop) (*MIB_IPFORWARD_ROW2, error) { + route := &MIB_IPFORWARD_ROW2{} + + initializeIPForwardEntry(route) + + // Convert interface index to luid if interface is specified + if nexthop.Intf != nil { + var luid luid + if err := convertInterfaceIndexToLUID(uint32(nexthop.Intf.Index), &luid); err != nil { + return nil, fmt.Errorf("convert interface index to luid: %w", err) + } + route.InterfaceLuid = luid + route.InterfaceIndex = uint32(nexthop.Intf.Index) } - routeCmd := uspfilter.GetSystem32Command("route") + if err := setDestinationPrefix(&route.DestinationPrefix, prefix); err != nil { + return nil, fmt.Errorf("set destination prefix: %w", err) + } - out, err := exec.Command(routeCmd, args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) + if nexthop.IP.IsValid() { + if err := setNextHop(&route.NextHop, nexthop.IP); err != nil { + return nil, fmt.Errorf("set next hop: %w", err) + } + } - if err != nil { - return fmt.Errorf("remove route: %w", err) + return route, nil +} + +// addRoute adds a route using Windows iphelper APIs +func addRoute(prefix netip.Prefix, nexthop Nexthop) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in addRoute: %v, stack trace: %s", r, debug.Stack()) + } + }() + + route, setupErr := setupRouteEntry(prefix, nexthop) + if setupErr != nil { + return fmt.Errorf("setup route entry: %w", setupErr) + } + + route.Metric = 1 + route.ValidLifetime = InfiniteLifetime + route.PreferredLifetime = InfiniteLifetime + + return createIPForwardEntry2(route) +} + +// deleteRoute deletes a route using Windows iphelper APIs +func deleteRoute(prefix netip.Prefix, nexthop Nexthop) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in deleteRoute: %v, stack trace: %s", r, debug.Stack()) + } + }() + + route, setupErr := setupRouteEntry(prefix, nexthop) + if setupErr != nil { + return fmt.Errorf("setup route entry: %w", setupErr) + } + + if err := getIPForwardEntry2(route); err != nil { + return fmt.Errorf("get route entry: %w", err) + } + + return deleteIPForwardEntry2(route) +} + +// setDestinationPrefix sets the destination prefix in the route structure +func setDestinationPrefix(prefix *IP_ADDRESS_PREFIX, dest netip.Prefix) error { + addr := dest.Addr() + prefix.PrefixLength = uint8(dest.Bits()) + + if addr.Is4() { + prefix.Prefix.sin6_family = windows.AF_INET + ip4 := addr.As4() + binary.BigEndian.PutUint32(prefix.Prefix.data[:4], + uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3])) + return nil + } + + if addr.Is6() { + prefix.Prefix.sin6_family = windows.AF_INET6 + ip6 := addr.As16() + copy(prefix.Prefix.data[4:20], ip6[:]) + + if zone := addr.Zone(); zone != "" { + if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil { + binary.BigEndian.PutUint32(prefix.Prefix.data[20:24], uint32(scopeID)) + } + } + return nil + } + + return fmt.Errorf("invalid address family") +} + +// setNextHop sets the next hop address in the route structure +func setNextHop(nextHop *SOCKADDR_INET_NEXTHOP, addr netip.Addr) error { + if addr.Is4() { + nextHop.sin6_family = windows.AF_INET + ip4 := addr.As4() + binary.BigEndian.PutUint32(nextHop.data[:4], + uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3])) + return nil + } + + if addr.Is6() { + nextHop.sin6_family = windows.AF_INET6 + ip6 := addr.As16() + copy(nextHop.data[4:20], ip6[:]) + + // Handle zone if present + if zone := addr.Zone(); zone != "" { + if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil { + binary.BigEndian.PutUint32(nextHop.data[20:24], uint32(scopeID)) + } + } + return nil + } + + return fmt.Errorf("invalid address family") +} + +// Windows API wrappers +func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error { + r1, _, e1 := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) + if r1 != 0 { + if e1 != 0 { + return fmt.Errorf("CreateIpForwardEntry2: %w", e1) + } + return fmt.Errorf("CreateIpForwardEntry2: code %d", r1) + } + return nil +} + +func deleteIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error { + r1, _, e1 := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) + if r1 != 0 { + if e1 != 0 { + return fmt.Errorf("DeleteIpForwardEntry2: %w", e1) + } + return fmt.Errorf("DeleteIpForwardEntry2: code %d", r1) + } + return nil +} + +func getIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error { + r1, _, e1 := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) + if r1 != 0 { + if e1 != 0 { + return fmt.Errorf("GetIpForwardEntry2: %w", e1) + } + return fmt.Errorf("GetIpForwardEntry2: code %d", r1) + } + return nil +} + +// https://learn.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-initializeipforwardentry +func initializeIPForwardEntry(route *MIB_IPFORWARD_ROW2) { + // Does not return anything. Trying to handle the error might return an uninitialized value. + _, _, _ = syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route))) +} + +func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *luid) error { + r1, _, e1 := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(), + uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID))) + if r1 != 0 { + if e1 != 0 { + return fmt.Errorf("ConvertInterfaceIndexToLuid: %w", e1) + } + return fmt.Errorf("ConvertInterfaceIndexToLuid: code %d", r1) } return nil } @@ -319,7 +492,7 @@ func cancelMibChangeNotify2(handle windows.Handle) error { } // GetRoutesFromTable returns the current routing table from with prefixes only. -// It ccaches the result for 2 seconds to avoid blocking the caller. +// It caches the result for 2 seconds to avoid blocking the caller. func GetRoutesFromTable() ([]netip.Prefix, error) { mux.Lock() defer mux.Unlock() @@ -388,35 +561,6 @@ func GetRoutes() ([]Route, error) { return routes, nil } -func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { - args := []string{"add", prefix.String()} - - if nexthop.IP.IsValid() { - ip := nexthop.IP.WithZone("") - args = append(args, ip.Unmap().String()) - } else { - addr := "0.0.0.0" - if prefix.Addr().Is6() { - addr = "::" - } - args = append(args, addr) - } - - if nexthop.Intf != nil { - args = append(args, "if", strconv.Itoa(nexthop.Intf.Index)) - } - - routeCmd := uspfilter.GetSystem32Command("route") - - out, err := exec.Command(routeCmd, args...).CombinedOutput() - log.Tracef("route %s: %s", strings.Join(args, " "), out) - if err != nil { - return fmt.Errorf("route add: %w", err) - } - - return nil -} - func isCacheDisabled() bool { return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" } diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 19b006017..523bd0b0d 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -5,18 +5,23 @@ import ( "encoding/json" "fmt" "net" + "net/netip" "os/exec" "strings" "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" nbnet "github.com/netbirdio/netbird/util/net" ) -var expectedExtInt = "Ethernet1" +var ( + expectedExternalInt = "Ethernet1" + expectedVPNint = "wgtest0" +) type RouteInfo struct { NextHop string `json:"nexthop"` @@ -43,8 +48,6 @@ type testCase struct { dialer dialer } -var expectedVPNint = "wgtest0" - var testCases = []testCase{ { name: "To external host without custom dialer via vpn", @@ -52,14 +55,14 @@ var testCases = []testCase{ expectedSourceIP: "100.64.0.1", expectedDestPrefix: "128.0.0.0/1", expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", + expectedInterface: expectedVPNint, dialer: &net.Dialer{}, }, { name: "To external host with custom dialer via physical interface", destination: "192.0.2.1:53", expectedDestPrefix: "192.0.2.1/32", - expectedInterface: expectedExtInt, + expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), }, @@ -67,24 +70,15 @@ var testCases = []testCase{ name: "To duplicate internal route with custom dialer via physical interface", destination: "10.0.0.2:53", expectedDestPrefix: "10.0.0.2/32", - expectedInterface: expectedExtInt, + expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedSourceIP: "127.0.0.1", - expectedDestPrefix: "10.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", - dialer: &net.Dialer{}, - }, { name: "To unique vpn route with custom dialer via physical interface", destination: "172.16.0.2:53", expectedDestPrefix: "172.16.0.2/32", - expectedInterface: expectedExtInt, + expectedInterface: expectedExternalInt, dialer: nbnet.NewDialer(), }, { @@ -93,7 +87,7 @@ var testCases = []testCase{ expectedSourceIP: "100.64.0.1", expectedDestPrefix: "172.16.0.0/12", expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", + expectedInterface: expectedVPNint, dialer: &net.Dialer{}, }, @@ -103,22 +97,14 @@ var testCases = []testCase{ expectedSourceIP: "100.64.0.1", expectedDestPrefix: "10.10.0.0/24", expectedNextHop: "0.0.0.0", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - }, - - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.2:53", - expectedSourceIP: "127.0.0.1", - expectedDestPrefix: "127.0.0.0/8", - expectedNextHop: "0.0.0.0", - expectedInterface: "Loopback Pseudo-Interface 1", + expectedInterface: expectedVPNint, dialer: &net.Dialer{}, }, } func TestRouting(t *testing.T) { + log.SetLevel(log.DebugLevel) + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { setupTestEnv(t) @@ -129,7 +115,7 @@ func TestRouting(t *testing.T) { require.NoError(t, err, "Failed to fetch interface IP") output := testRoute(t, tc.destination, tc.dialer) - if tc.expectedInterface == expectedExtInt { + if tc.expectedInterface == expectedExternalInt { verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) } else { verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) @@ -242,19 +228,23 @@ func setupDummyInterfacesAndRoutes(t *testing.T) { func addDummyRoute(t *testing.T, dstCIDR string) { t.Helper() - script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR) - - output, err := exec.Command("powershell", "-Command", script).CombinedOutput() + prefix, err := netip.ParsePrefix(dstCIDR) if err != nil { - t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output) - t.FailNow() + t.Fatalf("Failed to parse destination CIDR %s: %v", dstCIDR, err) + } + + nexthop := Nexthop{ + Intf: &net.Interface{Index: 1}, + } + + if err = addRoute(prefix, nexthop); err != nil { + t.Fatalf("Failed to add dummy route: %v", err) } t.Cleanup(func() { - script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR) - output, err := exec.Command("powershell", "-Command", script).CombinedOutput() + err := deleteRoute(prefix, nexthop) if err != nil { - t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output) + t.Logf("Failed to remove dummy route: %v", err) } }) } diff --git a/client/server/trace.go b/client/server/trace.go index 8b9d375f3..e4ac91487 100644 --- a/client/server/trace.go +++ b/client/server/trace.go @@ -3,11 +3,11 @@ package server import ( "context" "fmt" - "net" "net/netip" fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" ) @@ -19,81 +19,32 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( s.mutex.Lock() defer s.mutex.Unlock() - if s.connectClient == nil { - return nil, fmt.Errorf("connect client not initialized") - } - engine := s.connectClient.Engine() - if engine == nil { - return nil, fmt.Errorf("engine not initialized") + tracer, engine, err := s.getPacketTracer() + if err != nil { + return nil, err } - fwManager := engine.GetFirewallManager() - if fwManager == nil { - return nil, fmt.Errorf("firewall manager not initialized") + srcAddr, err := s.parseAddress(req.GetSourceIp(), engine) + if err != nil { + return nil, fmt.Errorf("invalid source IP address: %w", err) } - tracer, ok := fwManager.(packetTracer) - if !ok { - return nil, fmt.Errorf("firewall manager does not support packet tracing") + dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine) + if err != nil { + return nil, fmt.Errorf("invalid destination IP address: %w", err) } - srcIP := net.ParseIP(req.GetSourceIp()) - if req.GetSourceIp() == "self" { - srcIP = engine.GetWgAddr() + protocol, err := s.parseProtocol(req.GetProtocol()) + if err != nil { + return nil, err } - srcAddr, ok := netip.AddrFromSlice(srcIP) - if !ok { - return nil, fmt.Errorf("invalid source IP address") + direction, err := s.parseDirection(req.GetDirection()) + if err != nil { + return nil, err } - dstIP := net.ParseIP(req.GetDestinationIp()) - if req.GetDestinationIp() == "self" { - dstIP = engine.GetWgAddr() - } - - dstAddr, ok := netip.AddrFromSlice(dstIP) - if !ok { - return nil, fmt.Errorf("invalid source IP address") - } - - if srcIP == nil || dstIP == nil { - return nil, fmt.Errorf("invalid IP address") - } - - var tcpState *uspfilter.TCPState - if flags := req.GetTcpFlags(); flags != nil { - tcpState = &uspfilter.TCPState{ - SYN: flags.GetSyn(), - ACK: flags.GetAck(), - FIN: flags.GetFin(), - RST: flags.GetRst(), - PSH: flags.GetPsh(), - URG: flags.GetUrg(), - } - } - - var dir fw.RuleDirection - switch req.GetDirection() { - case "in": - dir = fw.RuleDirectionIN - case "out": - dir = fw.RuleDirectionOUT - default: - return nil, fmt.Errorf("invalid direction") - } - - var protocol fw.Protocol - switch req.GetProtocol() { - case "tcp": - protocol = fw.ProtocolTCP - case "udp": - protocol = fw.ProtocolUDP - case "icmp": - protocol = fw.ProtocolICMP - default: - return nil, fmt.Errorf("invalid protocolcol") - } + tcpState := s.parseTCPFlags(req.GetTcpFlags()) builder := &uspfilter.PacketBuilder{ SrcIP: srcAddr, @@ -101,16 +52,96 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( Protocol: protocol, SrcPort: uint16(req.GetSourcePort()), DstPort: uint16(req.GetDestinationPort()), - Direction: dir, + Direction: direction, TCPState: tcpState, ICMPType: uint8(req.GetIcmpType()), ICMPCode: uint8(req.GetIcmpCode()), } + trace, err := tracer.TracePacketFromBuilder(builder) if err != nil { return nil, fmt.Errorf("trace packet: %w", err) } + return s.buildTraceResponse(trace), nil +} + +func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) { + if s.connectClient == nil { + return nil, nil, fmt.Errorf("connect client not initialized") + } + + engine := s.connectClient.Engine() + if engine == nil { + return nil, nil, fmt.Errorf("engine not initialized") + } + + fwManager := engine.GetFirewallManager() + if fwManager == nil { + return nil, nil, fmt.Errorf("firewall manager not initialized") + } + + tracer, ok := fwManager.(packetTracer) + if !ok { + return nil, nil, fmt.Errorf("firewall manager does not support packet tracing") + } + + return tracer, engine, nil +} + +func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) { + if addr == "self" { + return engine.GetWgAddr(), nil + } + + a, err := netip.ParseAddr(addr) + if err != nil { + return netip.Addr{}, err + } + + return a.Unmap(), nil +} + +func (s *Server) parseProtocol(protocol string) (fw.Protocol, error) { + switch protocol { + case "tcp": + return fw.ProtocolTCP, nil + case "udp": + return fw.ProtocolUDP, nil + case "icmp": + return fw.ProtocolICMP, nil + default: + return "", fmt.Errorf("invalid protocol") + } +} + +func (s *Server) parseDirection(direction string) (fw.RuleDirection, error) { + switch direction { + case "in": + return fw.RuleDirectionIN, nil + case "out": + return fw.RuleDirectionOUT, nil + default: + return 0, fmt.Errorf("invalid direction") + } +} + +func (s *Server) parseTCPFlags(flags *proto.TCPFlags) *uspfilter.TCPState { + if flags == nil { + return nil + } + + return &uspfilter.TCPState{ + SYN: flags.GetSyn(), + ACK: flags.GetAck(), + FIN: flags.GetFin(), + RST: flags.GetRst(), + PSH: flags.GetPsh(), + URG: flags.GetUrg(), + } +} + +func (s *Server) buildTraceResponse(trace *uspfilter.PacketTrace) *proto.TracePacketResponse { resp := &proto.TracePacketResponse{} for _, result := range trace.Results { @@ -119,10 +150,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( Message: result.Message, Allowed: result.Allowed, } + if result.ForwarderAction != nil { details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr) stage.ForwardingDetails = &details } + resp.Stages = append(resp.Stages, stage) } @@ -130,5 +163,5 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) ( resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed } - return resp, nil + return resp } diff --git a/util/net/net.go b/util/net/net.go index b573f9aeb..fdcf4ee6a 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,8 +1,10 @@ package net import ( + "fmt" "math/big" "net" + "net/netip" "github.com/google/uuid" ) @@ -54,11 +56,13 @@ func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } -func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP { - // Calculate the last IP in the CIDR range +func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP - for i := 0; i < len(network.IP); i++ { - endIP = append(endIP, network.IP[i]|^network.Mask[i]) + addr := network.Addr().AsSlice() + mask := net.CIDRMask(network.Bits(), len(addr)*8) + + for i := 0; i < len(addr); i++ { + endIP = append(endIP, addr[i]|^mask[i]) } // convert to big.Int @@ -70,5 +74,10 @@ func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP { resultInt := big.NewInt(0) resultInt.Sub(endInt, fromEndBig) - return resultInt.Bytes() + ip, ok := netip.AddrFromSlice(resultInt.Bytes()) + if !ok { + return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network) + } + + return ip.Unmap(), nil } diff --git a/util/net/net_test.go b/util/net/net_test.go new file mode 100644 index 000000000..e0633cb6a --- /dev/null +++ b/util/net/net_test.go @@ -0,0 +1,94 @@ +package net + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetLastIPFromNetwork(t *testing.T) { + tests := []struct { + name string + network string + fromEnd int + expected string + expectErr bool + }{ + { + name: "IPv4 /24 network - last IP (fromEnd=0)", + network: "192.168.1.0/24", + fromEnd: 0, + expected: "192.168.1.255", + }, + { + name: "IPv4 /24 network - fromEnd=1", + network: "192.168.1.0/24", + fromEnd: 1, + expected: "192.168.1.254", + }, + { + name: "IPv4 /24 network - fromEnd=5", + network: "192.168.1.0/24", + fromEnd: 5, + expected: "192.168.1.250", + }, + { + name: "IPv4 /16 network - last IP", + network: "10.0.0.0/16", + fromEnd: 0, + expected: "10.0.255.255", + }, + { + name: "IPv4 /16 network - fromEnd=256", + network: "10.0.0.0/16", + fromEnd: 256, + expected: "10.0.254.255", + }, + { + name: "IPv4 /32 network - single host", + network: "192.168.1.100/32", + fromEnd: 0, + expected: "192.168.1.100", + }, + { + name: "IPv6 /64 network - last IP", + network: "2001:db8::/64", + fromEnd: 0, + expected: "2001:db8::ffff:ffff:ffff:ffff", + }, + { + name: "IPv6 /64 network - fromEnd=1", + network: "2001:db8::/64", + fromEnd: 1, + expected: "2001:db8::ffff:ffff:ffff:fffe", + }, + { + name: "IPv6 /128 network - single host", + network: "2001:db8::1/128", + fromEnd: 0, + expected: "2001:db8::1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network, err := netip.ParsePrefix(tt.network) + require.NoError(t, err, "Failed to parse network prefix") + + result, err := GetLastIPFromNetwork(network, tt.fromEnd) + + if tt.expectErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + expectedIP, err := netip.ParseAddr(tt.expected) + require.NoError(t, err, "Failed to parse expected IP") + + assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd) + }) + } +}