mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-11 09:21:49 +01:00
2475473227
All routes are now installed in a custom netbird routing table. Management and WireGuard traffic is now marked with a custom fwmark. When the mark is present, the traffic is routed via the main routing table, bypassing the VPN. When the mark is absent, the traffic is routed via the netbird routing table if either: - there's no match in the main routing table, or - it would match the default route in the main routing table. IPv6 traffic is blocked when a default IPv4 route is configured, to avoid leakage.
470 lines
13 KiB
Go
470 lines
13 KiB
Go
//go:build !android
|
|
|
|
package routemanager
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"strings"
|
|
"syscall"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gopacket/gopacket"
|
|
"github.com/gopacket/gopacket/layers"
|
|
"github.com/gopacket/gopacket/pcap"
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/vishvananda/netlink"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
"github.com/netbirdio/netbird/iface"
|
|
nbnet "github.com/netbirdio/netbird/util/net"
|
|
)
|
|
|
|
// PacketExpectation describes what a captured packet is expected to carry.
// Addresses are compared at the IP layer; ports are only checked for the
// transport protocol(s) whose flag is set.
type PacketExpectation struct {
	// SrcIP and DstIP are the expected IP-layer source and destination.
	SrcIP, DstIP net.IP
	// SrcPort and DstPort are the expected transport-layer ports.
	SrcPort, DstPort int
	// UDP and TCP select which transport header the ports are checked against.
	UDP, TCP bool
}
|
|
|
|
func TestEntryExists(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
|
|
|
content := []string{
|
|
"1000 reserved",
|
|
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
|
"9999 other_table",
|
|
}
|
|
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
|
|
|
file, err := os.Open(tempFilePath)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
assert.NoError(t, file.Close())
|
|
}()
|
|
|
|
tests := []struct {
|
|
name string
|
|
id int
|
|
shouldExist bool
|
|
err error
|
|
}{
|
|
{
|
|
name: "ExistsWithNetbirdPrefix",
|
|
id: 7120,
|
|
shouldExist: true,
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "ExistsWithDifferentName",
|
|
id: 1000,
|
|
shouldExist: true,
|
|
err: ErrTableIDExists,
|
|
},
|
|
{
|
|
name: "DoesNotExist",
|
|
id: 1234,
|
|
shouldExist: false,
|
|
err: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
exists, err := entryExists(file, tc.id)
|
|
if tc.err != nil {
|
|
assert.ErrorIs(t, err, tc.err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
assert.Equal(t, tc.shouldExist, exists)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRoutingWithTables(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
destination string
|
|
captureInterface string
|
|
dialer *net.Dialer
|
|
packetExpectation PacketExpectation
|
|
}{
|
|
{
|
|
name: "To external host without fwmark via vpn",
|
|
destination: "192.0.2.1:53",
|
|
captureInterface: "wgtest0",
|
|
dialer: &net.Dialer{},
|
|
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
|
},
|
|
{
|
|
name: "To external host with fwmark via physical interface",
|
|
destination: "192.0.2.1:53",
|
|
captureInterface: "dummyext0",
|
|
dialer: nbnet.NewDialer(),
|
|
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
|
},
|
|
|
|
{
|
|
name: "To duplicate internal route with fwmark via physical interface",
|
|
destination: "10.0.0.1:53",
|
|
captureInterface: "dummyint0",
|
|
dialer: nbnet.NewDialer(),
|
|
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
|
},
|
|
{
|
|
name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence
|
|
destination: "10.0.0.1:53",
|
|
captureInterface: "dummyint0",
|
|
dialer: &net.Dialer{},
|
|
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53),
|
|
},
|
|
|
|
{
|
|
name: "To unique vpn route with fwmark via physical interface",
|
|
destination: "172.16.0.1:53",
|
|
captureInterface: "dummyext0",
|
|
dialer: nbnet.NewDialer(),
|
|
packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53),
|
|
},
|
|
{
|
|
name: "To unique vpn route without fwmark via vpn",
|
|
destination: "172.16.0.1:53",
|
|
captureInterface: "wgtest0",
|
|
dialer: &net.Dialer{},
|
|
packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53),
|
|
},
|
|
|
|
{
|
|
name: "To more specific route without fwmark via vpn interface",
|
|
destination: "10.10.0.1:53",
|
|
captureInterface: "dummyint0",
|
|
dialer: &net.Dialer{},
|
|
packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53),
|
|
},
|
|
|
|
{
|
|
name: "To more specific route (local) without fwmark via physical interface",
|
|
destination: "127.0.10.1:53",
|
|
captureInterface: "lo",
|
|
dialer: &net.Dialer{},
|
|
packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
wgIface, _, _ := setupTestEnv(t)
|
|
|
|
// default route exists in main table and vpn table
|
|
err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name())
|
|
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
|
|
// 10.0.0.0/8 route exists in main table and vpn table
|
|
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name())
|
|
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
|
|
// 10.10.0.0/24 more specific route exists in vpn table
|
|
err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
|
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
|
|
// 127.0.10.0/24 more specific route exists in vpn table
|
|
err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name())
|
|
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
|
|
// unique route in vpn table
|
|
err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name())
|
|
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
|
|
|
filter := createBPFFilter(tc.destination)
|
|
handle := startPacketCapture(t, tc.captureInterface, filter)
|
|
|
|
sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer)
|
|
|
|
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
|
packet, err := packetSource.NextPacket()
|
|
require.NoError(t, err)
|
|
|
|
verifyPacket(t, packet, tc.packetExpectation)
|
|
})
|
|
}
|
|
}
|
|
|
|
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
|
t.Helper()
|
|
|
|
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
|
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
|
|
|
ip, ok := ipLayer.(*layers.IPv4)
|
|
require.True(t, ok, "Failed to cast to IPv4 layer")
|
|
|
|
// Convert both source and destination IP addresses to 16-byte representation
|
|
expectedSrcIP := exp.SrcIP.To16()
|
|
actualSrcIP := ip.SrcIP.To16()
|
|
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
|
|
|
expectedDstIP := exp.DstIP.To16()
|
|
actualDstIP := ip.DstIP.To16()
|
|
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
|
|
|
if exp.UDP {
|
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
|
|
|
udp, ok := udpLayer.(*layers.UDP)
|
|
require.True(t, ok, "Failed to cast to UDP layer")
|
|
|
|
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
|
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
|
}
|
|
|
|
if exp.TCP {
|
|
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
|
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
|
|
|
tcp, ok := tcpLayer.(*layers.TCP)
|
|
require.True(t, ok, "Failed to cast to TCP layer")
|
|
|
|
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
|
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
|
}
|
|
|
|
}
|
|
|
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy {
|
|
t.Helper()
|
|
|
|
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
|
err := netlink.LinkDel(dummy)
|
|
if err != nil && !errors.Is(err, syscall.EINVAL) {
|
|
t.Logf("Failed to delete dummy interface: %v", err)
|
|
}
|
|
|
|
err = netlink.LinkAdd(dummy)
|
|
require.NoError(t, err)
|
|
|
|
err = netlink.LinkSetUp(dummy)
|
|
require.NoError(t, err)
|
|
|
|
if ipAddressCIDR != "" {
|
|
addr, err := netlink.ParseAddr(ipAddressCIDR)
|
|
require.NoError(t, err)
|
|
err = netlink.AddrAdd(dummy, addr)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
return dummy
|
|
}
|
|
|
|
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) {
|
|
t.Helper()
|
|
|
|
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
|
require.NoError(t, err)
|
|
|
|
if dstIPNet.String() == "0.0.0.0/0" {
|
|
gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4)
|
|
if err != nil {
|
|
t.Logf("Failed to fetch original gateway: %v", err)
|
|
}
|
|
|
|
// Handle existing routes with metric 0
|
|
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
|
if err == nil {
|
|
t.Cleanup(func() {
|
|
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0})
|
|
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
|
t.Fatalf("Failed to add route: %v", err)
|
|
}
|
|
})
|
|
} else if !errors.Is(err, syscall.ESRCH) {
|
|
t.Logf("Failed to delete route: %v", err)
|
|
}
|
|
}
|
|
|
|
route := &netlink.Route{
|
|
Dst: dstIPNet,
|
|
Gw: gw,
|
|
LinkIndex: linkIndex,
|
|
}
|
|
err = netlink.RouteDel(route)
|
|
if err != nil && !errors.Is(err, syscall.ESRCH) {
|
|
t.Logf("Failed to delete route: %v", err)
|
|
}
|
|
|
|
err = netlink.RouteAdd(route)
|
|
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
|
t.Fatalf("Failed to add route: %v", err)
|
|
}
|
|
}
|
|
|
|
// fetchOriginalGateway returns the original gateway IP address and the interface index.
|
|
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
|
routes, err := netlink.RouteList(nil, family)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
for _, route := range routes {
|
|
if route.Dst == nil {
|
|
return route.Gw, route.LinkIndex, nil
|
|
}
|
|
}
|
|
|
|
return nil, 0, fmt.Errorf("default route not found")
|
|
}
|
|
|
|
func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) {
|
|
t.Helper()
|
|
|
|
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
|
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index)
|
|
|
|
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
|
addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index)
|
|
|
|
t.Cleanup(func() {
|
|
err := netlink.LinkDel(defaultDummy)
|
|
assert.NoError(t, err)
|
|
err = netlink.LinkDel(otherDummy)
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
return defaultDummy.Name, otherDummy.Name
|
|
}
|
|
|
|
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
|
t.Helper()
|
|
|
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
require.NoError(t, err)
|
|
|
|
newNet, err := stdnet.NewNet(nil)
|
|
require.NoError(t, err)
|
|
|
|
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
|
require.NoError(t, err, "should create testing WireGuard interface")
|
|
|
|
err = wgInterface.Create()
|
|
require.NoError(t, err, "should create testing WireGuard interface")
|
|
|
|
t.Cleanup(func() {
|
|
wgInterface.Close()
|
|
})
|
|
|
|
return wgInterface
|
|
}
|
|
|
|
func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) {
|
|
t.Helper()
|
|
|
|
defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t)
|
|
|
|
wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820)
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, wgIface.Close())
|
|
})
|
|
|
|
err := setupRouting()
|
|
require.NoError(t, err, "setupRouting should not return err")
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, cleanupRouting())
|
|
})
|
|
|
|
return wgIface, defaultDummy, otherDummy
|
|
}
|
|
|
|
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
|
t.Helper()
|
|
|
|
inactive, err := pcap.NewInactiveHandle(intf)
|
|
require.NoError(t, err, "Failed to create inactive pcap handle")
|
|
defer inactive.CleanUp()
|
|
|
|
err = inactive.SetSnapLen(1600)
|
|
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
|
|
|
err = inactive.SetTimeout(time.Second * 10)
|
|
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
|
|
|
err = inactive.SetImmediateMode(true)
|
|
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
|
|
|
handle, err := inactive.Activate()
|
|
require.NoError(t, err, "Failed to activate pcap handle")
|
|
t.Cleanup(handle.Close)
|
|
|
|
err = handle.SetBPFFilter(filter)
|
|
require.NoError(t, err, "Failed to set BPF filter")
|
|
|
|
return handle
|
|
}
|
|
|
|
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) {
|
|
t.Helper()
|
|
|
|
if dialer == nil {
|
|
dialer = &net.Dialer{}
|
|
}
|
|
|
|
if sourcePort != 0 {
|
|
localUDPAddr := &net.UDPAddr{
|
|
IP: net.IPv4zero,
|
|
Port: sourcePort,
|
|
}
|
|
dialer.LocalAddr = localUDPAddr
|
|
}
|
|
|
|
msg := new(dns.Msg)
|
|
msg.Id = dns.Id()
|
|
msg.RecursionDesired = true
|
|
msg.Question = []dns.Question{
|
|
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
|
}
|
|
|
|
conn, err := dialer.Dial("udp", destination)
|
|
require.NoError(t, err, "Failed to dial UDP")
|
|
defer conn.Close()
|
|
|
|
data, err := msg.Pack()
|
|
require.NoError(t, err, "Failed to pack DNS message")
|
|
|
|
_, err = conn.Write(data)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "required key not available") {
|
|
t.Logf("Ignoring WireGuard key error: %v", err)
|
|
return
|
|
}
|
|
t.Fatalf("Failed to send DNS query: %v", err)
|
|
}
|
|
}
|
|
|
|
// createBPFFilter returns a pcap BPF expression that matches UDP packets sent
// to destination, given as "host:port" (IPv6 hosts in brackets).
//
// If destination cannot be split into host and port, it falls back to
// matching all UDP traffic.
func createBPFFilter(destination string) string {
	host, port, err := net.SplitHostPort(destination)
	// The original condition was inverted: it returned the catch-all filter on
	// success and formatted empty host/port values on failure.
	if err != nil {
		return "udp"
	}
	return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
|
|
|
|
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
|
return PacketExpectation{
|
|
SrcIP: net.ParseIP(srcIP),
|
|
DstIP: net.ParseIP(dstIP),
|
|
SrcPort: srcPort,
|
|
DstPort: dstPort,
|
|
UDP: true,
|
|
}
|
|
}
|