// Source: netbird/client/internal/routemanager/systemops_windows.go

//go:build windows

package routemanager

import (
	"fmt"
	"net"
	"net/netip"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
	"github.com/yusufpapurcu/wmi"

	"github.com/netbirdio/netbird/client/firewall/uspfilter"
	"github.com/netbirdio/netbird/client/internal/peer"
	"github.com/netbirdio/netbird/iface"
)
// MSFT_NetRoute holds one row of the WMI query that GetRoutes issues against
// the MSFT_NetRoute class in the ROOT\StandardCimv2 namespace. Each field name
// matches a property listed in that query's SELECT clause.
// NOTE(review): the wmi package presumably fills fields by matching names, so
// renaming a field would likely break decoding — confirm before changing.
type MSFT_NetRoute struct {
DestinationPrefix string // parsed with netip.ParsePrefix by GetRoutes
NextHop string // parsed with netip.ParseAddr by GetRoutes
InterfaceIndex int32 // 0 makes GetRoutes leave Route.Interface nil
InterfaceAlias string // copied into Route.Interface.Name
AddressFamily uint16 // selected but not read by GetRoutes
}
// Route is a parsed routing-table entry produced by GetRoutes.
// Interface is nil when the underlying WMI row reported InterfaceIndex 0;
// otherwise it is a minimal net.Interface with only Index and Name (taken
// from the WMI InterfaceAlias) filled in.
type Route struct {
Destination netip.Prefix
Nexthop netip.Addr
Interface *net.Interface
}
// MSFT_NetNeighbor holds one row of the WMI query that GetNeighbors issues
// against the MSFT_NetNeighbor class in the ROOT\StandardCimv2 namespace. Each
// field name matches a property listed in that query's SELECT clause.
// NOTE(review): the wmi package presumably fills fields by matching names —
// confirm before renaming any field.
type MSFT_NetNeighbor struct {
IPAddress string // parsed with netip.ParseAddr by GetNeighbors
LinkLayerAddress string
State uint8
AddressFamily uint16
InterfaceIndex uint32
InterfaceAlias string
}
// Neighbor is a parsed neighbor-table entry produced by GetNeighbors.
// IPAddress is parsed from the WMI IPAddress string; every other field is
// copied unchanged from the corresponding MSFT_NetNeighbor row.
// NOTE(review): State and AddressFamily are raw WMI values; their enum
// meanings are not interpreted anywhere in this file.
type Neighbor struct {
	IPAddress        netip.Addr
	LinkLayerAddress string
	State            uint8
	AddressFamily    uint16
	InterfaceIndex   uint32
	InterfaceAlias   string
}
// Package-level state for this platform's routing support.
//
// prefixList and lastUpdate form the route-table cache that
// getRoutesFromTable reads and refreshes while holding mux. routeManager is
// the instance that setupRouting initialises and cleanupRouting tears down.
var (
	prefixList   []netip.Prefix
	lastUpdate   time.Time
	mux          sync.Mutex
	routeManager *RouteManager
)
// setupRouting prepares routing on Windows by delegating to
// setupRoutingWithRouteManager with the package-level routeManager. It
// returns the peer hooks and error exactly as produced by that call.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
	beforeAdd, afterRemove, err := setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
	return beforeAdd, afterRemove, err
}
// cleanupRouting undoes setupRouting by handing the package-level
// routeManager to cleanupRoutingWithRouteManager and returning its error.
func cleanupRouting() error {
	err := cleanupRoutingWithRouteManager(routeManager)
	return err
}
func getRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
routes, err := GetRoutes()
2023-06-09 19:15:39 +02:00
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
2023-06-09 19:15:39 +02:00
}
prefixList = nil
2023-06-09 19:15:39 +02:00
for _, route := range routes {
prefixList = append(prefixList, route.Destination)
}
lastUpdate = time.Now()
return prefixList, nil
}
func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
var routes []Route
for _, entry := range entries {
dest, err := netip.ParsePrefix(entry.DestinationPrefix)
if err != nil {
log.Warnf("Unable to parse route destination %s: %v", entry.DestinationPrefix, err)
continue
}
nexthop, err := netip.ParseAddr(entry.NextHop)
if err != nil {
log.Warnf("Unable to parse route next hop %s: %v", entry.NextHop, err)
continue
}
var intf *net.Interface
if entry.InterfaceIndex != 0 {
intf = &net.Interface{
Index: int(entry.InterfaceIndex),
Name: entry.InterfaceAlias,
}
2023-06-12 11:43:18 +02:00
}
routes = append(routes, Route{
Destination: dest,
Nexthop: nexthop,
Interface: intf,
})
2023-06-09 19:15:39 +02:00
}
return routes, nil
}
func GetNeighbors() ([]Neighbor, error) {
var entries []MSFT_NetNeighbor
query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err)
}
var neighbors []Neighbor
for _, entry := range entries {
addr, err := netip.ParseAddr(entry.IPAddress)
if err != nil {
log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err)
continue
}
neighbors = append(neighbors, Neighbor{
IPAddress: addr,
LinkLayerAddress: entry.LinkLayerAddress,
State: entry.State,
AddressFamily: entry.AddressFamily,
InterfaceIndex: entry.InterfaceIndex,
InterfaceAlias: entry.InterfaceAlias,
})
}
return neighbors, nil
2023-06-09 19:15:39 +02:00
}
// addRouteCmd adds a route for prefix by running the System32 `route add`
// command.
//
// The gateway argument depends on nexthop:
//   - a valid nexthop is passed in its IPv4-unmapped form;
//   - otherwise the unspecified address of prefix's family is used
//     ("0.0.0.0" for IPv4, "::" for IPv6).
//
// When intf is non-nil, its index is appended after the "if" keyword. The
// command's combined output is logged at trace level. On failure the output
// is also included in the returned error, so the cause is visible without
// trace logging.
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
	gateway := "0.0.0.0"
	switch {
	case nexthop.IsValid():
		gateway = nexthop.Unmap().String()
	case prefix.Addr().Is6():
		gateway = "::"
	}

	args := []string{"add", prefix.String(), gateway}
	if intf != nil {
		args = append(args, "if", strconv.Itoa(intf.Index))
	}

	routeCmd := uspfilter.GetSystem32Command("route")
	out, err := exec.Command(routeCmd, args...).CombinedOutput()
	log.Tracef("route %s: %s", strings.Join(args, " "), out)
	if err != nil {
		return fmt.Errorf("route add: %w, output: %s", err, strings.TrimSpace(string(out)))
	}
	return nil
}
// addToRouteTable adds a route for prefix via nexthop and optionally binds it
// to intf.
//
// If intf is nil but nexthop carries a zone, the zone is parsed as a numeric
// interface index and used as the interface. The zone is then stripped from
// the gateway address. A non-numeric zone is rejected with an error.
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
	if zone := nexthop.Zone(); zone != "" && intf == nil {
		index, err := strconv.Atoi(zone)
		if err != nil {
			return fmt.Errorf("invalid zone: %w", err)
		}
		intf = &net.Interface{Index: index}
		// netip.Addr is an immutable value: WithZone returns a new address, so
		// it must be reassigned for the zone to actually be removed.
		nexthop = nexthop.WithZone("")
	}
	return addRouteCmd(prefix, nexthop, intf)
}
// removeFromRouteTable deletes the route for prefix by running the System32
// `route delete` command.
//
// When nexthop is valid, it is passed as the gateway argument with any zone
// stripped, in its IPv4-unmapped form. The interface argument is accepted for
// signature compatibility but is not used. The command's combined output is
// logged at trace level. On failure the output is also included in the
// returned error.
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
	args := []string{"delete", prefix.String()}
	if nexthop.IsValid() {
		// WithZone returns a new value; reassign so the zone is really dropped.
		nexthop = nexthop.WithZone("")
		args = append(args, nexthop.Unmap().String())
	}

	routeCmd := uspfilter.GetSystem32Command("route")
	out, err := exec.Command(routeCmd, args...).CombinedOutput()
	log.Tracef("route %s: %s", strings.Join(args, " "), out)
	if err != nil {
		return fmt.Errorf("remove route: %w, output: %s", err, strings.TrimSpace(string(out)))
	}
	return nil
}
// isCacheDisabled reports whether the route-table cache used by
// getRoutesFromTable is switched off. Caching is disabled only when the
// NB_DISABLE_ROUTE_CACHE environment variable is exactly "true"
// (case-sensitive). Any other value, or an unset variable, leaves it enabled.
func isCacheDisabled() bool {
	switch os.Getenv("NB_DISABLE_ROUTE_CACHE") {
	case "true":
		return true
	default:
		return false
	}
}