mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 19:00:50 +01:00
290 lines
9.0 KiB
Go
290 lines
9.0 KiB
Go
|
package routemanager
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"os/exec"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
|
||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||
|
)
|
||
|
|
||
|
// expectedExtInt is a placeholder interface alias. Test cases whose
// expectedInterface equals it are checked in TestRouting against the host's
// original default route (its interface alias, next hop and that interface's
// IPv4 address) instead of against hard-coded values.
var expectedExtInt = "Ethernet1"
|
||
|
|
||
|
// RouteInfo holds the Get-NetRoute properties the tests use to describe the
// host's default route. The JSON tags are the exact property names that
// ConvertTo-Json emits for `Select-Object NextHop, RouteMetric, InterfaceAlias`.
// This keeps decoding from relying on encoding/json's case-insensitive fallback
// and matches the convention used by FindNetRouteOutput.
type RouteInfo struct {
	// NextHop is the gateway address of the default route.
	NextHop string `json:"NextHop"`
	// InterfaceAlias is the alias of the interface the default route uses.
	InterfaceAlias string `json:"InterfaceAlias"`
	// RouteMetric is the route's own metric as reported by Get-NetRoute.
	RouteMetric int `json:"RouteMetric"`
}
|
||
|
|
||
|
// FindNetRouteOutput is one JSON object produced by testRoute's
// `Find-NetRoute | Select-Object ... | ConvertTo-Json` query.
// Find-NetRoute can return several objects for one destination. combineOutputs
// merges them by keeping the last non-empty value of each field.
type FindNetRouteOutput struct {
	// IPAddress is the local address reported for the destination
	// (verifyOutput compares it as the source IP).
	IPAddress string `json:"IPAddress"`
	// InterfaceIndex is the index of the interface Find-NetRoute reports.
	InterfaceIndex int `json:"InterfaceIndex"`
	// InterfaceAlias is the alias of the interface Find-NetRoute reports.
	InterfaceAlias string `json:"InterfaceAlias"`
	// AddressFamily is the numeric address family as emitted by ConvertTo-Json.
	AddressFamily int `json:"AddressFamily"`
	// NextHop is the next-hop address of the selected route.
	NextHop string `json:"NextHop"`
	// DestinationPrefix is the prefix of the selected route.
	DestinationPrefix string `json:"DestinationPrefix"`
}
|
||
|
|
||
|
type testCase struct {
|
||
|
name string
|
||
|
destination string
|
||
|
expectedSourceIP string
|
||
|
expectedDestPrefix string
|
||
|
expectedNextHop string
|
||
|
expectedInterface string
|
||
|
dialer dialer
|
||
|
}
|
||
|
|
||
|
// expectedVPNint is the alias of the VPN test interface that the VPN-routed
// test cases expect Find-NetRoute to report.
var expectedVPNint = "wgtest0"
|
||
|
|
||
|
var testCases = []testCase{
|
||
|
{
|
||
|
name: "To external host without custom dialer via vpn",
|
||
|
destination: "192.0.2.1:53",
|
||
|
expectedSourceIP: "100.64.0.1",
|
||
|
expectedDestPrefix: "128.0.0.0/1",
|
||
|
expectedNextHop: "0.0.0.0",
|
||
|
expectedInterface: "wgtest0",
|
||
|
dialer: &net.Dialer{},
|
||
|
},
|
||
|
{
|
||
|
name: "To external host with custom dialer via physical interface",
|
||
|
destination: "192.0.2.1:53",
|
||
|
expectedDestPrefix: "192.0.2.1/32",
|
||
|
expectedInterface: expectedExtInt,
|
||
|
dialer: nbnet.NewDialer(),
|
||
|
},
|
||
|
|
||
|
{
|
||
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||
|
destination: "10.0.0.2:53",
|
||
|
expectedDestPrefix: "10.0.0.2/32",
|
||
|
expectedInterface: expectedExtInt,
|
||
|
dialer: nbnet.NewDialer(),
|
||
|
},
|
||
|
{
|
||
|
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||
|
destination: "10.0.0.2:53",
|
||
|
expectedSourceIP: "10.0.0.1",
|
||
|
expectedDestPrefix: "10.0.0.0/8",
|
||
|
expectedNextHop: "0.0.0.0",
|
||
|
expectedInterface: "Loopback Pseudo-Interface 1",
|
||
|
dialer: &net.Dialer{},
|
||
|
},
|
||
|
|
||
|
{
|
||
|
name: "To unique vpn route with custom dialer via physical interface",
|
||
|
destination: "172.16.0.2:53",
|
||
|
expectedDestPrefix: "172.16.0.2/32",
|
||
|
expectedInterface: expectedExtInt,
|
||
|
dialer: nbnet.NewDialer(),
|
||
|
},
|
||
|
{
|
||
|
name: "To unique vpn route without custom dialer via vpn",
|
||
|
destination: "172.16.0.2:53",
|
||
|
expectedSourceIP: "100.64.0.1",
|
||
|
expectedDestPrefix: "172.16.0.0/12",
|
||
|
expectedNextHop: "0.0.0.0",
|
||
|
expectedInterface: "wgtest0",
|
||
|
dialer: &net.Dialer{},
|
||
|
},
|
||
|
|
||
|
{
|
||
|
name: "To more specific route without custom dialer via vpn interface",
|
||
|
destination: "10.10.0.2:53",
|
||
|
expectedSourceIP: "100.64.0.1",
|
||
|
expectedDestPrefix: "10.10.0.0/24",
|
||
|
expectedNextHop: "0.0.0.0",
|
||
|
expectedInterface: "wgtest0",
|
||
|
dialer: &net.Dialer{},
|
||
|
},
|
||
|
|
||
|
{
|
||
|
name: "To more specific route (local) without custom dialer via physical interface",
|
||
|
destination: "127.0.10.2:53",
|
||
|
expectedSourceIP: "10.0.0.1",
|
||
|
expectedDestPrefix: "127.0.0.0/8",
|
||
|
expectedNextHop: "0.0.0.0",
|
||
|
expectedInterface: "Loopback Pseudo-Interface 1",
|
||
|
dialer: &net.Dialer{},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
func TestRouting(t *testing.T) {
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
setupTestEnv(t)
|
||
|
|
||
|
route, err := fetchOriginalGateway()
|
||
|
require.NoError(t, err, "Failed to fetch original gateway")
|
||
|
ip, err := fetchInterfaceIP(route.InterfaceAlias)
|
||
|
require.NoError(t, err, "Failed to fetch interface IP")
|
||
|
|
||
|
output := testRoute(t, tc.destination, tc.dialer)
|
||
|
if tc.expectedInterface == expectedExtInt {
|
||
|
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
||
|
} else {
|
||
|
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// fetchInterfaceIP returns an IPv4 address assigned to the interface with the
// given alias, as listed by Get-NetIPAddress (AddressFamily 2).
//
// PowerShell prints one address per line. If the interface carries several
// IPv4 addresses, the first one listed is returned. Returning the raw output
// would yield a multi-line string that never equals a single address.
// NOTE(review): on such hosts, the source address Windows picks for a route
// may differ from the first listed one.
//
// It returns an error if PowerShell fails or no IPv4 address is listed.
func fetchInterfaceIP(interfaceAlias string) (string, error) {
	script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
	out, err := exec.Command("powershell", "-Command", script).Output()
	if err != nil {
		return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
	}

	// Fields splits on any whitespace, which also drops the CRLF line endings.
	addresses := strings.Fields(string(out))
	if len(addresses) == 0 {
		return "", fmt.Errorf("no IPv4 address found on interface %q", interfaceAlias)
	}
	return addresses[0], nil
}
|
||
|
|
||
|
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
|
||
|
t.Helper()
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||
|
defer cancel()
|
||
|
|
||
|
conn, err := dialer.DialContext(ctx, "udp", destination)
|
||
|
require.NoError(t, err, "Failed to dial destination")
|
||
|
defer func() {
|
||
|
err := conn.Close()
|
||
|
assert.NoError(t, err, "Failed to close connection")
|
||
|
}()
|
||
|
|
||
|
host, _, err := net.SplitHostPort(destination)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||
|
|
||
|
out, err := exec.Command("powershell", "-Command", script).Output()
|
||
|
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||
|
|
||
|
var outputs []FindNetRouteOutput
|
||
|
err = json.Unmarshal(out, &outputs)
|
||
|
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
|
||
|
|
||
|
require.Greater(t, len(outputs), 0, "No route found for destination")
|
||
|
combinedOutput := combineOutputs(outputs)
|
||
|
|
||
|
return combinedOutput
|
||
|
}
|
||
|
|
||
|
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||
|
t.Helper()
|
||
|
|
||
|
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
|
||
|
require.NoError(t, err)
|
||
|
subnetMaskSize, _ := ipNet.Mask.Size()
|
||
|
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
|
||
|
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||
|
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
|
||
|
|
||
|
// Wait for the IP address to be applied
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
|
defer cancel()
|
||
|
err = waitForIPAddress(ctx, interfaceName, ip.String())
|
||
|
require.NoError(t, err, "IP address not applied within timeout")
|
||
|
|
||
|
t.Cleanup(func() {
|
||
|
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
|
||
|
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||
|
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
|
||
|
})
|
||
|
|
||
|
return interfaceName
|
||
|
}
|
||
|
|
||
|
func fetchOriginalGateway() (*RouteInfo, error) {
|
||
|
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||
|
output, err := cmd.CombinedOutput()
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
||
|
}
|
||
|
|
||
|
var routeInfo RouteInfo
|
||
|
err = json.Unmarshal(output, &routeInfo)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
|
||
|
}
|
||
|
|
||
|
return &routeInfo, nil
|
||
|
}
|
||
|
|
||
|
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
|
||
|
t.Helper()
|
||
|
|
||
|
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
|
||
|
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
|
||
|
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
|
||
|
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
|
||
|
}
|
||
|
|
||
|
// waitForIPAddress polls Get-NetIPAddress until expectedIPAddress is listed on
// the interface interfaceAlias and then returns nil.
//
// The first check runs immediately and later checks run once per second. Each
// PowerShell invocation is bound to ctx, so a hung command cannot outlive the
// deadline. It returns ctx.Err() when ctx is done before the address appears,
// or a wrapped error (including PowerShell's output) if the command fails.
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)

	for {
		// Do not start PowerShell at all once the context is already done.
		if err := ctx.Err(); err != nil {
			return err
		}

		out, err := exec.CommandContext(ctx, "powershell", "-Command", script).CombinedOutput()
		if err != nil {
			// The command is killed when ctx ends; report that as the cause.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return fmt.Errorf("failed to execute Get-NetIPAddress: %w: %s", err, strings.TrimSpace(string(out)))
		}

		// One address per line; Fields also strips the CRLF line endings.
		for _, ip := range strings.Fields(string(out)) {
			if ip == expectedIPAddress {
				return nil
			}
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
}
|
||
|
|
||
|
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
||
|
var combined FindNetRouteOutput
|
||
|
|
||
|
for _, output := range outputs {
|
||
|
if output.IPAddress != "" {
|
||
|
combined.IPAddress = output.IPAddress
|
||
|
}
|
||
|
if output.InterfaceIndex != 0 {
|
||
|
combined.InterfaceIndex = output.InterfaceIndex
|
||
|
}
|
||
|
if output.InterfaceAlias != "" {
|
||
|
combined.InterfaceAlias = output.InterfaceAlias
|
||
|
}
|
||
|
if output.AddressFamily != 0 {
|
||
|
combined.AddressFamily = output.AddressFamily
|
||
|
}
|
||
|
if output.NextHop != "" {
|
||
|
combined.NextHop = output.NextHop
|
||
|
}
|
||
|
if output.DestinationPrefix != "" {
|
||
|
combined.DestinationPrefix = output.DestinationPrefix
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return &combined
|
||
|
}
|
||
|
|
||
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||
|
t.Helper()
|
||
|
|
||
|
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
|
||
|
}
|