diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index 29df7ea7f..51135a729 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -65,7 +65,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca continue } - if !route.Dst.Addr().IsUnspecified() { + if route.Dst.Bits() != 0 { continue } diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go index f5cc19473..19648edba 100644 --- a/client/internal/networkmonitor/monitor_generic.go +++ b/client/internal/networkmonitor/monitor_generic.go @@ -59,7 +59,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error // recover in case sys ops panic defer func() { if r := recover(); r != nil { - err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack()) } }() diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index 308b2aa45..cd48c269d 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -3,252 +3,73 @@ package networkmonitor import ( "context" "fmt" - "net" - "net/netip" "strings" - "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -const ( - unreachable = 0 - incomplete = 1 - probe = 2 - delay = 3 - stale = 4 - reachable = 5 - permanent = 6 - tbd = 7 -) - -const interval = 10 * time.Second - func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { - var neighborv4, neighborv6 *systemops.Neighbor - { - initialNeighbors, err := getNeighbors() - if err != nil { - return fmt.Errorf("get neighbors: %w", err) - } - - neighborv4 = assignNeighbor(nexthopv4, initialNeighbors) - neighborv6 = assignNeighbor(nexthopv6, initialNeighbors) + routeMonitor, err := systemops.NewRouteMonitor(ctx) + if err != nil { + return fmt.Errorf("failed to create route monitor: %w", err) } - log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6) - - ticker := time.NewTicker(interval) - defer ticker.Stop() + defer func() { + if err := routeMonitor.Stop(); err != nil { + log.Errorf("Network monitor: failed to stop route monitor: %v", err) + } + }() for { select { case <-ctx.Done(): return ErrStopped - case <-ticker.C: - if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) { - go callback() - return nil + case route := <-routeMonitor.RouteUpdates(): + if route.Destination.Bits() != 0 { + continue + } + + if routeChanged(route, nexthopv4, nexthopv6, callback) { + break } } } } -func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor { - if n, ok := initialNeighbors[nexthop.IP]; ok && - n.State != unreachable && - n.State != incomplete && - n.State != tbd { - return &n - } - return nil -} - -func changed( - nexthopv4 systemops.Nexthop, - neighborv4 *systemops.Neighbor, - nexthopv6 systemops.Nexthop, - neighborv6 *systemops.Neighbor, -) bool { - neighbors, err := getNeighbors() - if err != nil { - log.Errorf("network monitor: error fetching current neighbors: %v", err) - return false - } - if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) { - return true - } - - routes, err := getRoutes() - if err != nil { - log.Errorf("network monitor: error fetching current routes: %v", err) - return false - } - - if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) { - return true - } - - return false -} - -// routeChanged checks if the default routes still point to our nexthop/interface -func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool { - if !nexthop.IP.IsValid() { - return false - } - - if isSoftInterface(nexthop.Intf.Name) { - log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name) - return false - } - - unspec := getUnspecifiedPrefix(nexthop.IP) - defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) - - log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n")) - - if !foundMatchingRoute { - logRouteChange(nexthop.IP, intf) - return true - } - - return false -} - -func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix { - if ip.Is6() { - return netip.PrefixFrom(netip.IPv6Unspecified(), 0) - } - return netip.PrefixFrom(netip.IPv4Unspecified(), 0) -} - -func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { - var defaultRoutes []string - foundMatchingRoute := false - - for _, r := range routes { - if r.Destination == unspec { - routeInfo := formatRouteInfo(r) - defaultRoutes = append(defaultRoutes, routeInfo) - - if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 { - foundMatchingRoute = true - log.Debugf("network monitor: found matching default route: %s", routeInfo) - } +func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + if isSoftInterface(intf) { + log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf) + return false } } - return defaultRoutes, foundMatchingRoute -} - -func formatRouteInfo(r systemops.Route) string { - newIntf := "" - if r.Interface != nil { - newIntf = r.Interface.Name - } - return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf) -} - -func logRouteChange(ip netip.Addr, intf *net.Interface) { - oldIntf := "" - if intf != nil { - oldIntf = intf.Name - } - log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf) -} - -func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool { - if neighbor == nil { - return false - } - - // TODO: consider non-local nexthops, e.g. on point-to-point interfaces - if n, ok := neighbors[nexthop.IP]; ok { - if n.State == unreachable || n.State == incomplete { - log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) - return true - } else if n.InterfaceIndex != neighbor.InterfaceIndex { - log.Infof( - "network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s", - neighbor.IPAddress, - neighbor.LinkLayerAddress, - neighbor.InterfaceAlias, - neighbor.InterfaceIndex, - n.InterfaceAlias, - n.InterfaceIndex, - stateFromInt(n.State), - ) + switch route.Type { + case systemops.RouteModified: + // TODO: get routing table to figure out if our route is affected for modified routes + log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) + go callback() + return true + case systemops.RouteAdded: + if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { + log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) + go callback() + return true + } + case systemops.RouteDeleted: + if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) + go callback() return true } - } else { - log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress) - return true } return false } -func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) { - entries, err := systemops.GetNeighbors() - if err != nil { - return nil, fmt.Errorf("get neighbors: %w", err) - } - - neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries)) - for _, entry := range entries { - neighbours[entry.IPAddress] = entry - } - - return neighbours, nil -} - -func getRoutes() ([]systemops.Route, error) { - entries, err := systemops.GetRoutes() - if err != nil { - return nil, fmt.Errorf("get routes: %w", err) - } - - return entries, nil -} - -func stateFromInt(state uint8) string { - switch state { - case unreachable: - return "unreachable" - case incomplete: - return "incomplete" - case probe: - return "probe" - case delay: - return "delay" - case stale: - return "stale" - case reachable: - return "reachable" - case permanent: - return "permanent" - case tbd: - return "tbd" - default: - return "unknown" - } -} - -func compareIntf(a, b *net.Interface) int { - switch { - case a == nil && b == nil: - return 0 - case a == nil: - return -1 - case b == nil: - return 1 - default: - return a.Index - b.Index - } -} - func isSoftInterface(name string) bool { return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 0d3630cb8..3f756788e 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -3,6 +3,8 @@ package systemops import ( + "context" + "encoding/binary" "fmt" "net" "net/netip" @@ -11,15 +13,43 @@ import ( "strconv" "strings" "sync" + "syscall" "time" + "unsafe" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/firewall/uspfilter" nbnet "github.com/netbirdio/netbird/util/net" ) +type RouteUpdateType int + +// RouteUpdate represents a change in the routing table. +// The interface field contains the index only. +type RouteUpdate struct { + Type RouteUpdateType + Destination netip.Prefix + NextHop netip.Addr + Interface *net.Interface +} + +// RouteMonitor provides a way to monitor changes in the routing table. +type RouteMonitor struct { + updates chan RouteUpdate + handle windows.Handle + done chan struct{} +} + +// Route represents a single routing table entry. +type Route struct { + Destination netip.Prefix + Nexthop netip.Addr + Interface *net.Interface +} + type MSFT_NetRoute struct { DestinationPrefix string NextHop string @@ -28,33 +58,77 @@ type MSFT_NetRoute struct { AddressFamily uint16 } -type Route struct { - Destination netip.Prefix - Nexthop netip.Addr - Interface *net.Interface +// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2 +type MIB_IPFORWARD_ROW2 struct { + InterfaceLuid uint64 + InterfaceIndex uint32 + DestinationPrefix IP_ADDRESS_PREFIX + NextHop SOCKADDR_INET_NEXTHOP + SitePrefixLength uint8 + ValidLifetime uint32 + PreferredLifetime uint32 + Metric uint32 + Protocol uint32 + Loopback uint8 + AutoconfigureAddress uint8 + Publish uint8 + Immortal uint8 + Age uint32 + Origin uint32 } -type MSFT_NetNeighbor struct { - IPAddress string - LinkLayerAddress string - State uint8 - AddressFamily uint16 - InterfaceIndex uint32 - InterfaceAlias string +// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix +type IP_ADDRESS_PREFIX struct { + Prefix SOCKADDR_INET + PrefixLength uint8 } -type Neighbor struct { - IPAddress netip.Addr - LinkLayerAddress string - State uint8 - AddressFamily uint16 - InterfaceIndex uint32 - InterfaceAlias string +// SOCKADDR_INET is defined in https://learn.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-sockaddr_inet +// It represents the union of IPv4 and IPv6 socket addresses +type SOCKADDR_INET struct { + sin6_family int16 + // nolint:unused + sin6_port uint16 + // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id + data [24]byte } -var prefixList []netip.Prefix -var lastUpdate time.Time -var mux = sync.Mutex{} +// SOCKADDR_INET_NEXTHOP is the same as SOCKADDR_INET but offset by 2 bytes +type SOCKADDR_INET_NEXTHOP struct { + // nolint:unused + pad [2]byte + sin6_family int16 + // nolint:unused + sin6_port uint16 + // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id + data [24]byte +} + +// MIB_NOTIFICATION_TYPE is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ne-netioapi-mib_notification_type +type MIB_NOTIFICATION_TYPE int32 + +var ( + modiphlpapi = windows.NewLazyDLL("iphlpapi.dll") + procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") + procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") + + prefixList []netip.Prefix + lastUpdate time.Time + mux sync.Mutex +) + +const ( + MibParemeterModification MIB_NOTIFICATION_TYPE = iota + MibAddInstance + MibDeleteInstance + MibInitialNotification +) + +const ( + RouteModified RouteUpdateType = iota + RouteAdded + RouteDeleted +) func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return r.setupRefCounter(initAddresses) @@ -94,6 +168,155 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro return nil } +// NewRouteMonitor creates and starts a new RouteMonitor. +// It returns a pointer to the RouteMonitor and an error if the monitor couldn't be started. +func NewRouteMonitor(ctx context.Context) (*RouteMonitor, error) { + rm := &RouteMonitor{ + updates: make(chan RouteUpdate, 5), + done: make(chan struct{}), + } + + if err := rm.start(ctx); err != nil { + return nil, err + } + + return rm, nil +} + +func (rm *RouteMonitor) start(ctx context.Context) error { + if ctx.Err() != nil { + return ctx.Err() + } + + callbackPtr := windows.NewCallback(func(callerContext uintptr, row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) uintptr { + if ctx.Err() != nil { + return 0 + } + + update, err := rm.parseUpdate(row, notificationType) + if err != nil { + log.Errorf("Failed to parse route update: %v", err) + return 0 + } + + select { + case <-rm.done: + return 0 + case rm.updates <- update: + default: + log.Warn("Route update channel is full, dropping update") + } + return 0 + }) + + var handle windows.Handle + if err := notifyRouteChange2(windows.AF_UNSPEC, callbackPtr, 0, false, &handle); err != nil { + return fmt.Errorf("NotifyRouteChange2 failed: %w", err) + } + + rm.handle = handle + + return nil +} + +func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) (RouteUpdate, error) { + // destination prefix, next hop, interface index, interface luid are guaranteed to be there + // GetIpForwardEntry2 is not needed + + var update RouteUpdate + + idx := int(row.InterfaceIndex) + if idx != 0 { + intf, err := net.InterfaceByIndex(idx) + if err != nil { + return update, fmt.Errorf("get interface name: %w", err) + } + + update.Interface = intf + } + + log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) + dest := parseIPPrefix(row.DestinationPrefix, idx) + if !dest.Addr().IsValid() { + return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row) + } + + nexthop := parseIPNexthop(row.NextHop, idx) + if !nexthop.IsValid() { + return RouteUpdate{}, fmt.Errorf("invalid next hop %v", row) + } + + updateType := RouteModified + switch notificationType { + case MibParemeterModification: + updateType = RouteModified + case MibAddInstance: + updateType = RouteAdded + case MibDeleteInstance: + updateType = RouteDeleted + } + + update.Type = updateType + update.Destination = dest + update.NextHop = nexthop + + return update, nil +} + +// Stop stops the RouteMonitor. +func (rm *RouteMonitor) Stop() error { + if rm.handle != 0 { + if err := cancelMibChangeNotify2(rm.handle); err != nil { + return fmt.Errorf("CancelMibChangeNotify2 failed: %w", err) + } + rm.handle = 0 + } + close(rm.done) + close(rm.updates) + return nil +} + +// RouteUpdates returns a channel that receives RouteUpdate messages. +func (rm *RouteMonitor) RouteUpdates() <-chan RouteUpdate { + return rm.updates +} + +func notifyRouteChange2(family uint32, callback uintptr, callerContext uintptr, initialNotification bool, handle *windows.Handle) error { + var initNotif uint32 + if initialNotification { + initNotif = 1 + } + + r1, _, e1 := syscall.SyscallN( + procNotifyRouteChange2.Addr(), + uintptr(family), + callback, + callerContext, + uintptr(initNotif), + uintptr(unsafe.Pointer(handle)), + ) + if r1 != 0 { + if e1 != 0 { + return e1 + } + return syscall.EINVAL + } + return nil +} + +func cancelMibChangeNotify2(handle windows.Handle) error { + r1, _, e1 := syscall.SyscallN(procCancelMibChangeNotify2.Addr(), uintptr(handle)) + if r1 != 0 { + if e1 != 0 { + return e1 + } + return syscall.EINVAL + } + return nil +} + +// GetRoutesFromTable returns the current routing table from with prefixes only. +// It ccaches the result for 2 seconds to avoid blocking the caller. func GetRoutesFromTable() ([]netip.Prefix, error) { mux.Lock() defer mux.Unlock() @@ -117,6 +340,7 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } +// GetRoutes retrieves the current routing table using WMI. func GetRoutes() ([]Route, error) { var entries []MSFT_NetRoute @@ -146,8 +370,8 @@ func GetRoutes() ([]Route, error) { Name: entry.InterfaceAlias, } - if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) { - nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex))) + if nexthop.Is6() { + nexthop = addZone(nexthop, int(entry.InterfaceIndex)) } } @@ -161,33 +385,6 @@ func GetRoutes() ([]Route, error) { return routes, nil } -func GetNeighbors() ([]Neighbor, error) { - var entries []MSFT_NetNeighbor - query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor` - if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { - return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err) - } - - var neighbors []Neighbor - for _, entry := range entries { - addr, err := netip.ParseAddr(entry.IPAddress) - if err != nil { - log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err) - continue - } - neighbors = append(neighbors, Neighbor{ - IPAddress: addr, - LinkLayerAddress: entry.LinkLayerAddress, - State: entry.State, - AddressFamily: entry.AddressFamily, - InterfaceIndex: entry.InterfaceIndex, - InterfaceAlias: entry.InterfaceAlias, - }) - } - - return neighbors, nil -} - func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { args := []string{"add", prefix.String()} @@ -220,3 +417,54 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { func isCacheDisabled() bool { return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" } + +func parseIPPrefix(prefix IP_ADDRESS_PREFIX, idx int) netip.Prefix { + ip := parseIP(prefix.Prefix, idx) + return netip.PrefixFrom(ip, int(prefix.PrefixLength)) +} + +func parseIP(addr SOCKADDR_INET, idx int) netip.Addr { + return parseIPGeneric(addr.sin6_family, addr.data, idx) +} + +func parseIPNexthop(addr SOCKADDR_INET_NEXTHOP, idx int) netip.Addr { + return parseIPGeneric(addr.sin6_family, addr.data, idx) +} + +func parseIPGeneric(family int16, data [24]byte, interfaceIndex int) netip.Addr { + switch family { + case windows.AF_INET: + ipv4 := binary.BigEndian.Uint32(data[:4]) + return netip.AddrFrom4([4]byte{ + byte(ipv4 >> 24), + byte(ipv4 >> 16), + byte(ipv4 >> 8), + byte(ipv4), + }) + + case windows.AF_INET6: + // The IPv6 address is stored after the 4-byte flowinfo field + var ipv6 [16]byte + copy(ipv6[:], data[4:20]) + ip := netip.AddrFrom16(ipv6) + + // Check if there's a non-zero scope_id + scopeID := binary.BigEndian.Uint32(data[20:24]) + if scopeID != 0 { + ip = ip.WithZone(strconv.FormatUint(uint64(scopeID), 10)) + } else if interfaceIndex != 0 { + ip = addZone(ip, interfaceIndex) + } + + return ip + } + + return netip.IPv4Unspecified() +} + +func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + ip = ip.WithZone(strconv.Itoa(interfaceIndex)) + } + return ip +}