netbird/client/internal/routemanager/systemops_linux_test.go

470 lines
13 KiB
Go
Raw Normal View History

//go:build !android
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
"github.com/gopacket/gopacket/pcap"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
// PacketExpectation describes the addressing that verifyPacket checks a
// captured packet against.
type PacketExpectation struct {
	// SrcIP and DstIP are the expected IP source and destination addresses.
	SrcIP, DstIP net.IP
	// SrcPort and DstPort are the expected transport-layer ports.
	SrcPort, DstPort int
	// UDP and TCP select which transport layers are inspected for the ports.
	UDP, TCP bool
}
// TestEntryExists writes a small rt_tables-style file into a temporary
// directory and checks what entryExists reports for a set of table IDs.
func TestEntryExists(t *testing.T) {
	lines := []string{
		"1000 reserved",
		fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
		"9999 other_table",
	}
	path := fmt.Sprintf("%s/rt_tables", t.TempDir())
	err := os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
	require.NoError(t, err)

	// The same open file is handed to every subtest below.
	rtTables, err := os.Open(path)
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, rtTables.Close())
	}()

	type entryCase struct {
		name        string
		id          int
		shouldExist bool
		err         error
	}
	cases := []entryCase{
		// NOTE(review): assumes NetbirdVPNTableID == 7120 — confirm against its definition.
		{name: "ExistsWithNetbirdPrefix", id: 7120, shouldExist: true},
		// ID present in the file under a foreign name: reported as existing, with ErrTableIDExists.
		{name: "ExistsWithDifferentName", id: 1000, shouldExist: true, err: ErrTableIDExists},
		{name: "DoesNotExist", id: 1234},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			exists, err := entryExists(rtTables, tt.id)
			if tt.err == nil {
				assert.NoError(t, err)
			} else {
				assert.ErrorIs(t, err, tt.err)
			}
			assert.Equal(t, tt.shouldExist, exists)
		})
	}
}
// TestRoutingWithTables sends a UDP DNS query to each destination, captures
// the first matching packet on the expected interface, and verifies its
// source/destination addressing. Every subtest rebuilds the test environment
// and installs the same set of routes via the WireGuard interface first.
func TestRoutingWithTables(t *testing.T) {
	type routingCase struct {
		name              string
		destination       string
		captureInterface  string
		dialer            *net.Dialer
		packetExpectation PacketExpectation
	}

	cases := []routingCase{
		{
			name:              "To external host without fwmark via vpn",
			destination:       "192.0.2.1:53",
			captureInterface:  "wgtest0",
			dialer:            &net.Dialer{},
			packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
		},
		{
			name:              "To external host with fwmark via physical interface",
			destination:       "192.0.2.1:53",
			captureInterface:  "dummyext0",
			dialer:            nbnet.NewDialer(),
			packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
		},
		{
			name:              "To duplicate internal route with fwmark via physical interface",
			destination:       "10.0.0.1:53",
			captureInterface:  "dummyint0",
			dialer:            nbnet.NewDialer(),
			packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
		},
		{
			name:              "To duplicate internal route without fwmark via physical interface", // local route takes precedence
			destination:       "10.0.0.1:53",
			captureInterface:  "dummyint0",
			dialer:            &net.Dialer{},
			packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
		},
		{
			name:              "To unique vpn route with fwmark via physical interface",
			destination:       "172.16.0.1:53",
			captureInterface:  "dummyext0",
			dialer:            nbnet.NewDialer(),
			packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
		},
		{
			name:              "To unique vpn route without fwmark via vpn",
			destination:       "172.16.0.1:53",
			captureInterface:  "wgtest0",
			dialer:            &net.Dialer{},
			packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
		},
		{
			name:              "To more specific route without fwmark via vpn interface",
			destination:       "10.10.0.1:53",
			captureInterface:  "dummyint0",
			dialer:            &net.Dialer{},
			packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
		},
		{
			name:              "To more specific route (local) without fwmark via physical interface",
			destination:       "127.0.10.1:53",
			captureInterface:  "lo",
			dialer:            &net.Dialer{},
			packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
		},
	}

	// Prefixes added through the WireGuard interface, in this order, before
	// each capture.
	vpnPrefixes := []string{
		"0.0.0.0/0",     // default route exists in main table and vpn table
		"10.0.0.0/8",    // 10.0.0.0/8 route exists in main table and vpn table
		"10.10.0.0/24",  // 10.10.0.0/24 more specific route exists in vpn table
		"127.0.10.0/24", // 127.0.10.0/24 more specific route exists in vpn table
		"172.16.0.0/16", // unique route in vpn table
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			wgIface, _, _ := setupTestEnv(t)

			for _, prefix := range vpnPrefixes {
				err := addToRouteTableIfNoExists(netip.MustParsePrefix(prefix), wgIface.Address().IP.String(), wgIface.Name())
				require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
			}

			// Start capturing before the packet is sent so it cannot be missed.
			handle := startPacketCapture(t, tt.captureInterface, createBPFFilter(tt.destination))
			sendTestPacket(t, tt.destination, tt.packetExpectation.SrcPort, tt.dialer)

			packet, err := gopacket.NewPacketSource(handle, handle.LinkType()).NextPacket()
			require.NoError(t, err)

			verifyPacket(t, packet, tt.packetExpectation)
		})
	}
}
// verifyPacket asserts that packet carries an IPv4 layer whose addresses
// match exp and, for each transport flagged in exp (UDP and/or TCP), that the
// corresponding layer is present with the expected source and destination
// ports. A missing layer aborts the test; value mismatches are recorded as
// failures.
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
	t.Helper()

	rawIP := packet.Layer(layers.LayerTypeIPv4)
	require.NotNil(t, rawIP, "Expected IPv4 layer not found in packet")

	ipv4, isIPv4 := rawIP.(*layers.IPv4)
	require.True(t, isIPv4, "Failed to cast to IPv4 layer")

	// Compare addresses in their 16-byte form so both sides use the same representation.
	assert.Equal(t, exp.SrcIP.To16(), ipv4.SrcIP.To16(), "Source IP mismatch")
	assert.Equal(t, exp.DstIP.To16(), ipv4.DstIP.To16(), "Destination IP mismatch")

	if exp.UDP {
		rawUDP := packet.Layer(layers.LayerTypeUDP)
		require.NotNil(t, rawUDP, "Expected UDP layer not found in packet")

		datagram, isUDP := rawUDP.(*layers.UDP)
		require.True(t, isUDP, "Failed to cast to UDP layer")

		assert.Equal(t, layers.UDPPort(exp.SrcPort), datagram.SrcPort, "UDP source port mismatch")
		assert.Equal(t, layers.UDPPort(exp.DstPort), datagram.DstPort, "UDP destination port mismatch")
	}

	if exp.TCP {
		rawTCP := packet.Layer(layers.LayerTypeTCP)
		require.NotNil(t, rawTCP, "Expected TCP layer not found in packet")

		segment, isTCP := rawTCP.(*layers.TCP)
		require.True(t, isTCP, "Failed to cast to TCP layer")

		assert.Equal(t, layers.TCPPort(exp.SrcPort), segment.SrcPort, "TCP source port mismatch")
		assert.Equal(t, layers.TCPPort(exp.DstPort), segment.DstPort, "TCP destination port mismatch")
	}
}
// createAndSetupDummyInterface (re)creates a dummy link named interfaceName,
// brings it up and, when ipAddressCIDR is non-empty, assigns that address to
// it. Any failure other than the initial delete aborts the test.
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
	t.Helper()

	link := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}

	// Drop any existing link first. A delete failure is only logged, and
	// EINVAL is not even logged (presumably "no such link" — confirm).
	if err := netlink.LinkDel(link); err != nil && !errors.Is(err, syscall.EINVAL) {
		t.Logf("Failed to delete dummy interface: %v", err)
	}

	require.NoError(t, netlink.LinkAdd(link))
	require.NoError(t, netlink.LinkSetUp(link))

	if ipAddressCIDR == "" {
		return link
	}

	addr, err := netlink.ParseAddr(ipAddressCIDR)
	require.NoError(t, err)
	require.NoError(t, netlink.AddrAdd(link, addr))

	return link
}
// addDummyRoute installs a route to dstCIDR via gw on the link with index
// linkIndex, replacing an identical route if one is already present.
//
// For the default route (0.0.0.0/0) it first removes an existing metric-0
// default route and, if that removal succeeded, registers a cleanup that
// re-adds it using the gateway and link index fetched beforehand.
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
	t.Helper()

	_, dstIPNet, err := net.ParseCIDR(dstCIDR)
	require.NoError(t, err)

	if dstIPNet.String() == "0.0.0.0/0" {
		// Distinct names keep the original gateway apart from the gw/linkIndex parameters.
		origGw, origLinkIndex, fetchErr := fetchOriginalGateway(netlink.FAMILY_V4)
		if fetchErr != nil {
			t.Logf("Failed to fetch original gateway: %v", fetchErr)
		}

		// Handle existing routes with metric 0
		delErr := netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
		switch {
		case delErr == nil:
			t.Cleanup(func() {
				restoreErr := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: origGw, LinkIndex: origLinkIndex, Priority: 0})
				if restoreErr != nil && !errors.Is(restoreErr, syscall.EEXIST) {
					t.Fatalf("Failed to add route: %v", restoreErr)
				}
			})
		case !errors.Is(delErr, syscall.ESRCH):
			t.Logf("Failed to delete route: %v", delErr)
		}
	}

	dummyRoute := &netlink.Route{
		Dst:       dstIPNet,
		Gw:        gw,
		LinkIndex: linkIndex,
	}

	// Delete-then-add so a stale copy of the same route does not get in the way.
	if delErr := netlink.RouteDel(dummyRoute); delErr != nil && !errors.Is(delErr, syscall.ESRCH) {
		t.Logf("Failed to delete route: %v", delErr)
	}
	if addErr := netlink.RouteAdd(dummyRoute); addErr != nil && !errors.Is(addErr, syscall.EEXIST) {
		t.Fatalf("Failed to add route: %v", addErr)
	}
}
// fetchOriginalGateway lists the routes of the given address family and
// returns the gateway IP and link index of the first route that has no
// destination set. It returns an error if the listing fails or if no such
// route is found.
func fetchOriginalGateway(family int) (net.IP, int, error) {
	routes, listErr := netlink.RouteList(nil, family)
	if listErr != nil {
		return nil, 0, listErr
	}

	for i := range routes {
		if r := &routes[i]; r.Dst == nil {
			return r.Gw, r.LinkIndex, nil
		}
	}

	return nil, 0, fmt.Errorf("default route not found")
}
// setupDummyInterfacesAndRoutes creates two dummy interfaces and a route on
// each: "dummyext0" (192.168.0.1/24) carrying the default route via
// 192.168.0.1, and "dummyint0" (192.168.1.1/24) carrying 10.0.0.0/8 without a
// gateway. Both links are deleted in a test cleanup. It returns the two
// interface names in that order.
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
	t.Helper()

	extLink := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
	addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), extLink.Attrs().Index)

	intLink := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
	addDummyRoute(t, "10.0.0.0/8", nil, intLink.Attrs().Index)

	t.Cleanup(func() {
		assert.NoError(t, netlink.LinkDel(extLink))
		assert.NoError(t, netlink.LinkDel(intLink))
	})

	return extLink.Name, intLink.Name
}
// createWGInterface builds and creates a WireGuard interface for tests using
// a freshly generated private key, the given name, address and listen port,
// and the default MTU. The interface is closed in a test cleanup.
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
	t.Helper()

	key, err := wgtypes.GeneratePrivateKey()
	require.NoError(t, err)

	transportNet, err := stdnet.NewNet(nil)
	require.NoError(t, err)

	wg, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, key.String(), iface.DefaultMTU, transportNet, nil)
	require.NoError(t, err, "should create testing WireGuard interface")
	require.NoError(t, wg.Create(), "should create testing WireGuard interface")

	t.Cleanup(func() {
		// Best-effort close: the result is deliberately not checked here.
		_ = wg.Close()
	})

	return wg
}
// setupTestEnv prepares the environment for one routing test: the dummy
// interfaces and their routes, the "wgtest0" WireGuard interface
// (100.64.0.1/24, port 51820), and the routing setup. It returns the
// WireGuard interface followed by the names of the two dummy interfaces.
//
// Cleanups run in reverse registration order, so routing is cleaned up before
// the WireGuard interface is closed.
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
	t.Helper()

	extName, intName := setupDummyInterfacesAndRoutes(t)

	wg := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
	t.Cleanup(func() {
		assert.NoError(t, wg.Close())
	})

	require.NoError(t, setupRouting(), "setupRouting should not return err")
	t.Cleanup(func() {
		assert.NoError(t, cleanupRouting())
	})

	return wg, extName, intName
}
// startPacketCapture opens a live capture on interface intf with a snap
// length of 1600 bytes, a 10 second timeout and immediate mode enabled, then
// applies the BPF filter. The returned handle is closed in a test cleanup.
// Any failure aborts the test.
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
	t.Helper()

	pending, err := pcap.NewInactiveHandle(intf)
	require.NoError(t, err, "Failed to create inactive pcap handle")
	// The inactive handle is released when this function returns.
	defer pending.CleanUp()

	require.NoError(t, pending.SetSnapLen(1600), "Failed to set snap length on inactive handle")
	require.NoError(t, pending.SetTimeout(time.Second*10), "Failed to set timeout on inactive handle")
	require.NoError(t, pending.SetImmediateMode(true), "Failed to set immediate mode on inactive handle")

	live, err := pending.Activate()
	require.NoError(t, err, "Failed to activate pcap handle")
	t.Cleanup(live.Close)

	require.NoError(t, live.SetBPFFilter(filter), "Failed to set BPF filter")

	return live
}
// sendTestPacket sends one DNS A query for "example.com." over UDP to
// destination using dialer (a zero net.Dialer when nil). A non-zero
// sourcePort is bound as the local UDP port by setting dialer.LocalAddr, which
// modifies the dialer passed in by the caller.
//
// A write error containing "required key not available" is logged and
// ignored; any other dial, pack or write failure aborts the test.
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
	t.Helper()

	if dialer == nil {
		dialer = &net.Dialer{}
	}
	if sourcePort != 0 {
		dialer.LocalAddr = &net.UDPAddr{
			IP:   net.IPv4zero,
			Port: sourcePort,
		}
	}

	query := new(dns.Msg)
	query.Id = dns.Id()
	query.RecursionDesired = true
	query.Question = []dns.Question{
		{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
	}

	conn, err := dialer.Dial("udp", destination)
	require.NoError(t, err, "Failed to dial UDP")
	defer conn.Close()

	payload, err := query.Pack()
	require.NoError(t, err, "Failed to pack DNS message")

	if _, err = conn.Write(payload); err == nil {
		return
	}
	if strings.Contains(err.Error(), "required key not available") {
		t.Logf("Ignoring WireGuard key error: %v", err)
		return
	}
	t.Fatalf("Failed to send DNS query: %v", err)
}
// createBPFFilter builds the capture filter for a "host:port" destination.
//
// For a well-formed destination the filter matches only UDP packets addressed
// to that host and port. If destination cannot be split into host and port,
// it falls back to matching all UDP traffic.
func createBPFFilter(destination string) string {
	host, port, err := net.SplitHostPort(destination)
	if err != nil {
		// host and port are empty on error; interpolating them would yield
		// an invalid filter expression, so use the broad fallback instead.
		return "udp"
	}
	return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
// createPacketExpectation returns a UDP PacketExpectation for the given
// source and destination addresses and ports. The IP strings are parsed with
// net.ParseIP, so an unparsable address yields a nil IP in the result.
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
	exp := PacketExpectation{UDP: true}
	exp.SrcIP = net.ParseIP(srcIP)
	exp.DstIP = net.ParseIP(dstIP)
	exp.SrcPort = srcPort
	exp.DstPort = dstPort
	return exp
}