diff --git a/client/cmd/root.go b/client/cmd/root.go index 999500787..1eca27d8c 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -32,6 +32,7 @@ const ( preSharedKeyFlag = "preshared-key" interfaceNameFlag = "interface-name" wireguardPortFlag = "wireguard-port" + networkMonitorFlag = "network-monitor" disableAutoConnectFlag = "disable-auto-connect" serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" @@ -62,6 +63,7 @@ var ( serverSSHAllowed bool interfaceName string wireguardPort uint16 + networkMonitor bool serviceName string autoConnectDisabled bool extraIFaceBlackList []string diff --git a/client/cmd/up.go b/client/cmd/up.go index c2c3c7c90..3af119c6b 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -40,6 +40,7 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") + upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") } @@ -116,6 +117,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { ic.WireguardPort = &p } + if cmd.Flag(networkMonitorFlag).Changed { + ic.NetworkMonitor = &networkMonitor + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { ic.PreSharedKey = &preSharedKey } @@ -226,6 +231,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { loginRequest.WireguardPort = &wp } + if cmd.Flag(networkMonitorFlag).Changed { + loginRequest.NetworkMonitor = &networkMonitor + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/internal/config.go b/client/internal/config.go index 5b3c61cbd..2beb853c0 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -48,6 +48,7 @@ type ConfigInput struct { RosenpassPermissive *bool InterfaceName *string WireguardPort *int + NetworkMonitor *bool DisableAutoConnect *bool ExtraIFaceBlackList []string } @@ -61,6 +62,7 @@ type Config struct { AdminURL *url.URL WgIface string WgPort int + NetworkMonitor bool IFaceBlackList []string DisableIPv6Discovery bool RosenpassEnabled bool @@ -188,6 +190,10 @@ func createNewConfig(input ConfigInput) (*Config, error) { config.WgPort = *input.WireguardPort } + if input.NetworkMonitor != nil { + config.NetworkMonitor = *input.NetworkMonitor + } + config.WgIface = iface.WgInterfaceDefault if input.InterfaceName != nil { config.WgIface = *input.InterfaceName @@ -279,6 +285,11 @@ func update(input ConfigInput) (*Config, error) { refresh = true } + if input.NetworkMonitor != nil { + config.NetworkMonitor = *input.NetworkMonitor + refresh = true + } + if input.WireguardPort != nil { config.WgPort = *input.WireguardPort refresh = true diff --git a/client/internal/connect.go b/client/internal/connect.go index c238fa31c..be71cdda9 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -249,7 +249,7 @@ func runClient( engineChan <- engine } - log.Print("Netbird engine started, my IP is: ", peerConfig.Address) + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) <-engineCtx.Done() @@ -297,6 +297,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe DisableIPv6Discovery: config.DisableIPv6Discovery, WgPrivateKey: key, WgPort: config.WgPort, + NetworkMonitor: config.NetworkMonitor, SSHKey: []byte(config.SSHKey), NATExternalIPs: config.NATExternalIPs, CustomDNSAddress: config.CustomDNSAddress, diff --git a/client/internal/engine.go b/client/internal/engine.go index 9ee804a87..c81b13210 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -2,6 +2,7 @@ package internal import ( "context" + "errors" "fmt" "math/rand" "net" @@ -21,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" @@ -60,6 +62,9 @@ type EngineConfig struct { // WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine) WgPrivateKey wgtypes.Key + // NetworkMonitor is a flag to enable network monitoring + NetworkMonitor bool + // IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related) IFaceBlackList []string DisableIPv6Discovery bool @@ -114,9 +119,11 @@ type Engine struct { // clientRoutes is the most recent list of clientRoutes received from the Management Service clientRoutes route.HAMap - cancel context.CancelFunc + clientCtx context.Context + clientCancel context.CancelFunc - ctx context.Context + ctx context.Context + cancel context.CancelFunc wgInterface *iface.WGIface wgProxyFactory *wgproxy.Factory @@ -126,6 +133,8 @@ type Engine struct { // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 + networkWatcher *networkmonitor.NetworkWatcher + sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServer nbssh.Server @@ -151,8 +160,8 @@ type Peer struct { // NewEngine creates a new Connection Engine func NewEngine( - ctx context.Context, - cancel context.CancelFunc, + clientCtx context.Context, + clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, config *EngineConfig, @@ -160,8 +169,8 @@ func NewEngine( statusRecorder *peer.Status, ) *Engine { return NewEngineWithProbes( - ctx, - cancel, + clientCtx, + clientCancel, signalClient, mgmClient, config, @@ -176,8 +185,8 @@ func NewEngine( // NewEngineWithProbes creates a new Connection Engine with probes attached func NewEngineWithProbes( - ctx context.Context, - cancel context.CancelFunc, + clientCtx context.Context, + clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, config *EngineConfig, @@ -188,9 +197,10 @@ func NewEngineWithProbes( relayProbe *Probe, wgProbe *Probe, ) *Engine { + return &Engine{ - ctx: ctx, - cancel: cancel, + clientCtx: clientCtx, + clientCancel: clientCancel, signal: signalClient, mgmClient: mgmClient, peerConns: make(map[string]*peer.Conn), @@ -202,7 +212,7 @@ func NewEngineWithProbes( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, - wgProxyFactory: wgproxy.NewFactory(config.WgPort), + networkWatcher: networkmonitor.New(), mgmProbe: mgmProbe, signalProbe: signalProbe, relayProbe: relayProbe, @@ -214,6 +224,13 @@ func (e *Engine) Stop() error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + if e.cancel != nil { + e.cancel() + } + + // stopping network monitor first to avoid starting the engine again + e.networkWatcher.Stop() + err := e.removeAllPeers() if err != nil { return err @@ -222,7 +239,7 @@ func (e *Engine) Stop() error { e.clientRoutes = nil // very ugly but we want to remove peers from the WireGuard interface first before removing interface. - // Removing peers happens in the conn.CLose() asynchronously + // Removing peers happens in the conn.Close() asynchronously time.Sleep(500 * time.Millisecond) e.close() @@ -237,6 +254,13 @@ func (e *Engine) Start() error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + if e.cancel != nil { + e.cancel() + } + e.ctx, e.cancel = context.WithCancel(e.clientCtx) + + e.wgProxyFactory = wgproxy.NewFactory(e.clientCtx, e.config.WgPort) + wgIface, err := e.newWgIface() if err != nil { log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err) @@ -320,6 +344,21 @@ func (e *Engine) Start() error { e.receiveManagementEvents() e.receiveProbeEvents() + if e.config.NetworkMonitor { + // starting network monitor at the very last to avoid disruptions + go e.networkWatcher.Start(e.ctx, func() { + log.Infof("Network monitor detected network change, restarting engine") + if err := e.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + if err := e.Start(); err != nil { + log.Errorf("Failed to start engine: %v", err) + } + }) + } else { + log.Infof("Network monitor is disabled, not starting") + } + return nil } @@ -588,12 +627,12 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { go func() { - err := e.mgmClient.Sync(e.handleSync) + err := e.mgmClient.Sync(e.ctx, e.handleSync) if err != nil { // happens if management is unavailable for a long time. // We want to cancel the operation of the whole client _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) - e.cancel() + e.clientCancel() return } log.Debugf("stopped receiving updates from Management Service") @@ -869,11 +908,12 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) { conn.UpdateStunTurn(append(e.STUNs, e.TURNs...)) e.syncMsgMux.Unlock() - err := conn.Open() + err := conn.Open(e.ctx) if err != nil { log.Debugf("connection to peer %s failed: %v", peerKey, err) - switch err.(type) { - case *peer.ConnectionClosedError: + var connectionClosedError *peer.ConnectionClosedError + switch { + case errors.As(err, &connectionClosedError): // conn has been forced to close, so we exit the loop return default: @@ -984,7 +1024,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e func (e *Engine) receiveSignalEvents() { go func() { // connect to a stream of messages coming from the signal server - err := e.signal.Receive(func(msg *sProto.Message) error { + err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1058,7 +1098,7 @@ func (e *Engine) receiveSignalEvents() { // happens if signal is unavailable for a long time. // We want to cancel the operation of the whole client _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) - e.cancel() + e.clientCancel() return } }() @@ -1119,13 +1159,16 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { - if err := e.wgProxyFactory.Free(); err != nil { - log.Errorf("failed closing ebpf proxy: %s", err) + if e.wgProxyFactory != nil { + if err := e.wgProxyFactory.Free(); err != nil { + log.Errorf("failed closing ebpf proxy: %s", err) + } } // stop/restore DNS first so dbus and friends don't complain because of a missing interface if e.dnsServer != nil { e.dnsServer.Stop() + e.dnsServer = nil } if e.routeManager != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 13a18cf39..1bac4145f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -392,7 +392,7 @@ func TestEngine_Sync(t *testing.T) { // feed updates to Engine via mocked Management client updates := make(chan *mgmtProto.SyncResponse) defer close(updates) - syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error { + syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error { for msg := range updates { err := msgHandler(msg) if err != nil { diff --git a/client/internal/engine_watcher.go b/client/internal/engine_watcher.go deleted file mode 100644 index 5bf0569ce..000000000 --- a/client/internal/engine_watcher.go +++ /dev/null @@ -1 +0,0 @@ -package internal diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go new file mode 100644 index 000000000..71cf031ba --- /dev/null +++ b/client/internal/networkmonitor/monitor.go @@ -0,0 +1,15 @@ +package networkmonitor + +import ( + "context" +) + +// NetworkWatcher watches for changes in network configuration. +type NetworkWatcher struct { + cancel context.CancelFunc +} + +// New creates a new network monitor. +func New() *NetworkWatcher { + return &NetworkWatcher{} +} diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go new file mode 100644 index 000000000..e15c08d7e --- /dev/null +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -0,0 +1,133 @@ +//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd + +package networkmonitor + +import ( + "context" + "fmt" + "net" + "net/netip" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager" +) + +func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error { + fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + if err != nil { + return fmt.Errorf("failed to open routing socket: %v", err) + } + defer func() { + if err := unix.Close(fd); err != nil { + log.Errorf("Network monitor: failed to close routing socket: %v", err) + } + }() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + log.Errorf("Network monitor: failed to read from routing socket: %v", err) + continue + } + if n < unix.SizeofRtMsghdr { + log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + + // handle interface state changes + case unix.RTM_IFINFO: + ifinfo, err := parseInterfaceMessage(buf[:n]) + if err != nil { + log.Errorf("Network monitor: error parsing interface message: %v", err) + continue + } + if msg.Flags&unix.IFF_UP != 0 { + continue + } + if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) { + continue + } + + log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name) + callback() + + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) + if err != nil { + log.Errorf("Network monitor: error parsing routing message: %v", err) + continue + } + + if !route.Dst.Addr().IsUnspecified() { + continue + } + + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } + switch msg.Type { + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + callback() + case unix.RTM_DELETE: + if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) + callback() + } + } + } + } + } +} + +func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) { + msgs, err := route.ParseRIB(route.RIBTypeInterface, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.InterfaceMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return msg, nil +} + +func parseRouteMessage(buf []byte) (*routemanager.Route, error) { + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return routemanager.MsgToRoute(msg) +} diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go new file mode 100644 index 000000000..329246c8f --- /dev/null +++ b/client/internal/networkmonitor/monitor_generic.go @@ -0,0 +1,82 @@ +//go:build !ios && !android + +package networkmonitor + +import ( + "context" + "errors" + "net" + "net/netip" + "runtime/debug" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager" +) + +// Start begins watching for network changes and calls the callback function and stops when a change is detected. +func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) { + if nw.cancel != nil { + log.Warn("Network monitor: already running, stopping previous watcher") + nw.Stop() + } + + if ctx.Err() != nil { + log.Info("Network monitor: not starting, context is already cancelled") + return + } + + ctx, nw.cancel = context.WithCancel(ctx) + defer nw.Stop() + + var nexthop4, nexthop6 netip.Addr + var intf4, intf6 *net.Interface + + operation := func() error { + var errv4, errv6 error + nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified()) + nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified()) + + if errv4 != nil && errv6 != nil { + return errors.New("failed to get default next hops") + } + + if errv4 == nil { + log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name) + } + if errv6 == nil { + log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name) + } + + // continue if either route was found + return nil + } + + expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) + + if err := backoff.Retry(operation, expBackOff); err != nil { + log.Errorf("Network monitor: failed to get default next hops: %v", err) + return + } + + // recover in case sys ops panic + defer func() { + if r := recover(); r != nil { + log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + } + }() + + if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) { + log.Errorf("Network monitor: failed to start: %v", err) + } +} + +// Stop stops the network monitor. +func (nw *NetworkWatcher) Stop() { + if nw.cancel != nil { + nw.cancel() + nw.cancel = nil + log.Info("Network monitor: stopped") + } +} diff --git a/client/internal/networkmonitor/monitor_linux.go b/client/internal/networkmonitor/monitor_linux.go new file mode 100644 index 000000000..f39f1235c --- /dev/null +++ b/client/internal/networkmonitor/monitor_linux.go @@ -0,0 +1,81 @@ +//go:build !android + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "syscall" + + log "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" +) + +func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error { + if intfv4 == nil && intfv6 == nil { + return errors.New("no interfaces available") + } + + linkChan := make(chan netlink.LinkUpdate) + done := make(chan struct{}) + defer close(done) + + if err := netlink.LinkSubscribe(linkChan, done); err != nil { + return fmt.Errorf("subscribe to link updates: %v", err) + } + + routeChan := make(chan netlink.RouteUpdate) + if err := netlink.RouteSubscribe(routeChan, done); err != nil { + return fmt.Errorf("subscribe to route updates: %v", err) + } + + log.Info("Network monitor: started") + for { + select { + case <-ctx.Done(): + return ctx.Err() + + // handle interface state changes + case update := <-linkChan: + if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) { + continue + } + + switch update.Header.Type { + case syscall.RTM_DELLINK: + log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name) + callback() + return nil + case syscall.RTM_NEWLINK: + if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown { + log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name) + callback() + return nil + } + } + + // handle route changes + case route := <-routeChan: + // default route and main table + if route.Dst != nil || route.Table != syscall.RT_TABLE_MAIN { + continue + } + switch route.Type { + // triggered on added/replaced routes + case syscall.RTM_NEWROUTE: + log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) + callback() + return nil + case syscall.RTM_DELROUTE: + if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) { + log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) + callback() + return nil + } + } + } + } +} diff --git a/client/internal/networkmonitor/monitor_mobile.go b/client/internal/networkmonitor/monitor_mobile.go new file mode 100644 index 000000000..988f296bb --- /dev/null +++ b/client/internal/networkmonitor/monitor_mobile.go @@ -0,0 +1,11 @@ +//go:build ios || android + +package networkmonitor + +import "context" + +func (nw *NetworkWatcher) Start(context.Context, func()) { +} + +func (nw *NetworkWatcher) Stop() { +} diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go new file mode 100644 index 000000000..f6c5d963f --- /dev/null +++ b/client/internal/networkmonitor/monitor_windows.go @@ -0,0 +1,215 @@ +package networkmonitor + +import ( + "context" + "fmt" + "net" + "net/netip" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager" +) + +const ( + unreachable = 0 + incomplete = 1 + probe = 2 + delay = 3 + stale = 4 + reachable = 5 + permanent = 6 + tbd = 7 +) + +const interval = 10 * time.Second + +func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error { + var neighborv4, neighborv6 *routemanager.Neighbor + { + initialNeighbors, err := getNeighbors() + if err != nil { + return fmt.Errorf("get neighbors: %w", err) + } + + if n, ok := initialNeighbors[nexthopv4]; ok { + neighborv4 = &n + } + if n, ok := initialNeighbors[nexthopv6]; ok { + neighborv6 = &n + } + } + log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) { + callback() + return nil + } + } + } +} + +func changed( + nexthopv4 netip.Addr, + intfv4 *net.Interface, + neighborv4 *routemanager.Neighbor, + nexthopv6 netip.Addr, + intfv6 *net.Interface, + neighborv6 *routemanager.Neighbor, +) bool { + neighbors, err := getNeighbors() + if err != nil { + log.Errorf("network monitor: error fetching current neighbors: %v", err) + return false + } + if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) { + return true + } + + routes, err := getRoutes() + if err != nil { + log.Errorf("network monitor: error fetching current routes: %v", err) + return false + } + + if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) { + return true + } + + return false +} + +// routeChanged checks if the default routes still point to our nexthop/interface +func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool { + if !nexthop.IsValid() { + return false + } + + var unspec netip.Prefix + if nexthop.Is6() { + unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } else { + unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + } + + if r, ok := routes[unspec]; ok { + if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 { + intf := "" + if r.Interface != nil { + intf = r.Interface.Name + } + log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf) + return true + } + } else { + log.Infof("network monitor: default route is gone") + return true + } + + return false + +} + +func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool { + if neighbor == nil { + return false + } + + // TODO: consider non-local nexthops, e.g. on point-to-point interfaces + if n, ok := neighbors[nexthop]; ok { + if n.State != reachable && n.State != permanent { + log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) + return true + } else if n.InterfaceIndex != neighbor.InterfaceIndex { + log.Infof( + "network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s", + neighbor.IPAddress, + neighbor.LinkLayerAddress, + neighbor.InterfaceAlias, + neighbor.InterfaceIndex, + n.InterfaceAlias, + n.InterfaceIndex, + stateFromInt(n.State), + ) + return true + } + } else { + log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress) + return true + } + + return false +} + +func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) { + entries, err := routemanager.GetNeighbors() + if err != nil { + return nil, fmt.Errorf("get neighbors: %w", err) + } + + neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries)) + for _, entry := range entries { + neighbours[entry.IPAddress] = entry + } + + return neighbours, nil +} + +func getRoutes() (map[netip.Prefix]routemanager.Route, error) { + entries, err := routemanager.GetRoutes() + if err != nil { + return nil, fmt.Errorf("get routes: %w", err) + } + + routes := make(map[netip.Prefix]routemanager.Route, len(entries)) + for _, entry := range entries { + routes[entry.Destination] = entry + } + + return routes, nil +} + +func stateFromInt(state uint8) string { + switch state { + case unreachable: + return "unreachable" + case incomplete: + return "incomplete" + case probe: + return "probe" + case delay: + return "delay" + case stale: + return "stale" + case reachable: + return "reachable" + case permanent: + return "permanent" + case tbd: + return "tbd" + default: + return "unknown" + } +} + +func compareIntf(a, b *net.Interface) int { + if a == nil && b == nil { + return 0 + } + if a == nil { + return -1 + } + if b == nil { + return 1 + } + return a.Index - b.Index +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index a0da82b8d..1ee8cdd79 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -276,7 +276,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType { // Open opens connection to the remote peer starting ICE candidate gathering process. // Blocks until connection has been closed or connection timeout. // ConnStatus will be set accordingly -func (conn *Conn) Open() error { +func (conn *Conn) Open(ctx context.Context) error { log.Debugf("trying to connect to peer %s", conn.config.Key) peerState := State{ @@ -336,7 +336,7 @@ func (conn *Conn) Open() error { // at this point we received offer/answer and we are ready to gather candidates conn.mu.Lock() conn.status = StatusConnecting - conn.ctx, conn.notifyDisconnected = context.WithCancel(context.Background()) + conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx) defer conn.notifyDisconnected() conn.mu.Unlock() @@ -423,7 +423,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem var endpoint net.Addr if isRelayCandidate(pair.Local) { log.Debugf("setup relay connection") - conn.wgProxy = conn.wgProxyFactory.GetProxy() + conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx) endpoint, err = conn.wgProxy.AddTurnConn(remoteConn) if err != nil { return nil, err @@ -448,9 +448,11 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { - _ = conn.wgProxy.CloseConn() + if err := conn.wgProxy.CloseConn(); err != nil { + log.Warnf("Failed to close turn connection: %v", err) + } } - return nil, err + return nil, fmt.Errorf("update peer: %w", err) } conn.status = StatusConnected @@ -730,7 +732,7 @@ func (conn *Conn) Close() error { // before conn.Open() another update from management arrives with peers: [1,2,3,4,5] // engine adds a new Conn for 4 and 5 // therefore peer 4 has 2 Conn objects - log.Warnf("connection has been already closed or attempted closing not started coonection %s", conn.config.Key) + log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key) return NewConnectionAlreadyClosed(conn.config.Key) } } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 5c550d0d7..c16134808 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -1,6 +1,7 @@ package peer import ( + "context" "sync" "testing" "time" @@ -35,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -50,7 +51,7 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -87,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -123,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -153,7 +154,7 @@ func TestConn_Status(t *testing.T) { } func TestConn_Close(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index 1f37a8a3c..bc506411c 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -35,7 +35,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { addr = netip.IPv6Unspecified() } - defaultGateway, _, err := getNextHop(addr) + defaultGateway, _, err := GetNextHop(addr) if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("get existing route gateway: %s", err) } @@ -60,7 +60,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { return nil } - gatewayHop, intf, err := getNextHop(defaultGateway) + gatewayHop, intf, err := GetNextHop(defaultGateway) if err != nil && !errors.Is(err, ErrRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } @@ -69,14 +69,14 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { return addToRouteTable(gatewayPrefix, gatewayHop, intf) } -func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { +func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { r, err := netroute.New() if err != nil { return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) } intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) if err != nil { - log.Warnf("Failed to get route for %s: %v", ip, err) + log.Debugf("Failed to get route for %s: %v", ip, err) return netip.Addr{}, nil, ErrRouteNotFound } @@ -163,7 +163,7 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNe } // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) + nexthop, intf, err := GetNextHop(addr) if err != nil { return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err) } @@ -319,11 +319,11 @@ func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { } func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified()) if err != nil && !errors.Is(err, ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) } - initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified()) if err != nil && !errors.Is(err, ErrRouteNotFound) { log.Errorf("Unable to get initial v6 default next hop: %v", err) } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index f8f76ed45..a3548a1f1 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -3,10 +3,11 @@ package routemanager import ( + "errors" "fmt" "net" "net/netip" - "errors" + "strconv" "syscall" "time" @@ -15,17 +16,22 @@ import ( "golang.org/x/net/route" ) +type Route struct { + Dst netip.Prefix + Gw netip.Addr + Interface *net.Interface +} -// TODO: fix here with retry and backoff func getRoutesFromTable() ([]netip.Prefix, error) { tab, err := retryFetchRIB() if err != nil { - return nil, err + return nil, fmt.Errorf("fetch RIB: %v", err) } msgs, err := route.ParseRIB(route.RIBTypeRoute, tab) if err != nil { - return nil, err + return nil, fmt.Errorf("parse RIB: %v", err) } + var prefixList []netip.Prefix for _, msg := range msgs { m := msg.(*route.RouteMessage) @@ -33,7 +39,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) { if m.Version < 3 || m.Version > 5 { return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version) } - if m.Type != 4 /* RTM_GET */ { + if m.Type != syscall.RTM_GET { return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type) } @@ -42,28 +48,13 @@ func getRoutesFromTable() ([]netip.Prefix, error) { continue } - if len(m.Addrs) < 3 { - log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs) + route, err := MsgToRoute(m) + if err != nil { + log.Warnf("Failed to parse route message: %v", err) continue } - - addr, ok := toNetIPAddr(m.Addrs[0]) - if !ok { - continue - } - - cidr := 32 - if mask := m.Addrs[2]; mask != nil { - cidr, ok = toCIDR(mask) - if !ok { - log.Debugf("Unexpected RIB message Addrs[2]: %v", mask) - continue - } - } - - routePrefix := netip.PrefixFrom(addr, cidr) - if routePrefix.IsValid() { - prefixList = append(prefixList, routePrefix) + if route.Dst.IsValid() { + prefixList = append(prefixList, route.Dst) } } return prefixList, nil @@ -75,7 +66,7 @@ func retryFetchRIB() ([]byte, error) { var err error out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if errors.Is(err, syscall.ENOMEM) { - log.Debug("retrying fetchRIB due to 'cannot allocate memory' error") + log.Debug("~etrying fetchRIB due to 'cannot allocate memory' error") return err } else if err != nil { return backoff.Permanent(err) @@ -95,22 +86,74 @@ func retryFetchRIB() ([]byte, error) { return out, nil } -func toNetIPAddr(a route.Addr) (netip.Addr, bool) { +func toNetIP(a route.Addr) netip.Addr { switch t := a.(type) { case *route.Inet4Addr: - return netip.AddrFrom4(t.IP), true + return netip.AddrFrom4(t.IP) + case *route.Inet6Addr: + ip := netip.AddrFrom16(t.IP) + if t.ZoneID != 0 { + ip.WithZone(strconv.Itoa(t.ZoneID)) + } + return ip default: - return netip.Addr{}, false + return netip.Addr{} } } -func toCIDR(a route.Addr) (int, bool) { +func ones(a route.Addr) (int, error) { switch t := a.(type) { case *route.Inet4Addr: - mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - cidr, _ := mask.Size() - return cidr, true + mask, _ := net.IPMask(t.IP[:]).Size() + return mask, nil + case *route.Inet6Addr: + mask, _ := net.IPMask(t.IP[:]).Size() + return mask, nil default: - return 0, false + return 0, fmt.Errorf("unexpected address type: %T", a) } } + +func MsgToRoute(msg *route.RouteMessage) (*Route, error) { + dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2] + + addr := toNetIP(dstIP) + + var nexthopAddr netip.Addr + var nexthopIntf *net.Interface + + switch t := nexthop.(type) { + case *route.Inet4Addr, *route.Inet6Addr: + nexthopAddr = toNetIP(t) + case *route.LinkAddr: + nexthopIntf = &net.Interface{ + Index: t.Index, + Name: t.Name, + } + default: + return nil, fmt.Errorf("unexpected next hop type: %T", t) + } + + var prefix netip.Prefix + + if dstMask == nil { + if addr.Is4() { + prefix = netip.PrefixFrom(addr, 32) + } else { + prefix = netip.PrefixFrom(addr, 128) + } + } else { + bits, err := ones(dstMask) + if err != nil { + return nil, fmt.Errorf("failed to parse mask: %v", dstMask) + } + prefix = netip.PrefixFrom(addr, bits) + } + + return &Route{ + Dst: prefix, + Gw: nexthopAddr, + Interface: nexthopIntf, + }, nil + +} diff --git a/client/internal/routemanager/systemops_bsd_test.go b/client/internal/routemanager/systemops_bsd_test.go new file mode 100644 index 000000000..81bca504c --- /dev/null +++ b/client/internal/routemanager/systemops_bsd_test.go @@ -0,0 +1,57 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package routemanager + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/route" +) + +func TestBits(t *testing.T) { + tests := []struct { + name string + addr route.Addr + want int + wantErr bool + }{ + { + name: "IPv4 all ones", + addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}}, + want: 32, + }, + { + name: "IPv4 normal mask", + addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}}, + want: 24, + }, + { + name: "IPv6 all ones", + addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}}, + want: 128, + }, + { + name: "IPv6 normal mask", + addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}}, + want: 64, + }, + { + name: "Unsupported type", + addr: &route.LinkAddr{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ones(tt.addr) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 8a92ac579..8bcf06dce 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -87,10 +87,10 @@ func TestAddRemoveRoutes(t *testing.T) { err = removeVPNRoute(testCase.prefix, intf) require.NoError(t, err, "genericRemoveVPNRoute should not return err") - prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") + prefixGateway, _, err := GetNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "GetNextHop should not return err") - internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -104,7 +104,7 @@ func TestAddRemoveRoutes(t *testing.T) { } func TestGetNextHop(t *testing.T) { - gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } @@ -130,7 +130,7 @@ func TestGetNextHop(t *testing.T) { } } - localIP, _, err := getNextHop(testingPrefix.Addr()) + localIP, _, err := GetNextHop(testingPrefix.Addr()) if err != nil { t.Fatal("shouldn't return error: ", err) } @@ -146,7 +146,7 @@ func TestGetNextHop(t *testing.T) { } func TestAddExistAndRemoveRoute(t *testing.T) { - defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -410,8 +410,8 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf return } - prefixGateway, _, err := getNextHop(prefix.Addr()) - require.NoError(t, err, "getNextHop should not return err") + prefixGateway, _, err := GetNextHop(prefix.Addr()) + require.NoError(t, err, "GetNextHop should not return err") if invert { assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") } else { diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index f9e75e2ed..cfe2639ec 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -20,9 +20,36 @@ import ( "github.com/netbirdio/netbird/iface" ) -type Win32_IP4RouteTable struct { - Destination string - Mask string +type MSFT_NetRoute struct { + DestinationPrefix string + NextHop string + InterfaceIndex int32 + InterfaceAlias string + AddressFamily uint16 +} + +type Route struct { + Destination netip.Prefix + Nexthop netip.Addr + Interface *net.Interface +} + +type MSFT_NetNeighbor struct { + IPAddress string + LinkLayerAddress string + State uint8 + AddressFamily uint16 + InterfaceIndex uint32 + InterfaceAlias string +} + +type Neighbor struct { + IPAddress netip.Addr + LinkLayerAddress string + State uint8 + AddressFamily uint16 + InterfaceIndex uint32 + InterfaceAlias string } var prefixList []netip.Prefix @@ -43,44 +70,92 @@ func getRoutesFromTable() ([]netip.Prefix, error) { mux.Lock() defer mux.Unlock() - query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" - // If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second { return prefixList, nil } - var routes []Win32_IP4RouteTable - err := wmi.Query(query, &routes) + routes, err := GetRoutes() if err != nil { return nil, fmt.Errorf("get routes: %w", err) } prefixList = nil for _, route := range routes { - addr, err := netip.ParseAddr(route.Destination) - if err != nil { - log.Warnf("Unable to parse route destination %s: %v", route.Destination, err) - continue - } - maskSlice := net.ParseIP(route.Mask).To4() - if maskSlice == nil { - log.Warnf("Unable to parse route mask %s", route.Mask) - continue - } - mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3]) - cidr, _ := mask.Size() - - routePrefix := netip.PrefixFrom(addr, cidr) - if routePrefix.IsValid() && routePrefix.Addr().Is4() { - prefixList = append(prefixList, routePrefix) - } + prefixList = append(prefixList, route.Destination) } lastUpdate = time.Now() return prefixList, nil } +func GetRoutes() ([]Route, error) { + var entries []MSFT_NetRoute + + query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute` + if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { + return nil, fmt.Errorf("get routes: %w", err) + } + + var routes []Route + for _, entry := range entries { + dest, err := netip.ParsePrefix(entry.DestinationPrefix) + if err != nil { + log.Warnf("Unable to parse route destination %s: %v", entry.DestinationPrefix, err) + continue + } + + nexthop, err := netip.ParseAddr(entry.NextHop) + if err != nil { + log.Warnf("Unable to parse route next hop %s: %v", entry.NextHop, err) + continue + } + + var intf *net.Interface + if entry.InterfaceIndex != 0 { + intf = &net.Interface{ + Index: int(entry.InterfaceIndex), + Name: entry.InterfaceAlias, + } + } + + routes = append(routes, Route{ + Destination: dest, + Nexthop: nexthop, + Interface: intf, + }) + } + + return routes, nil +} + +func GetNeighbors() ([]Neighbor, error) { + var entries []MSFT_NetNeighbor + query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor` + if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { + return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err) + } + + var neighbors []Neighbor + for _, entry := range entries { + addr, err := netip.ParseAddr(entry.IPAddress) + if err != nil { + log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err) + continue + } + neighbors = append(neighbors, Neighbor{ + IPAddress: addr, + LinkLayerAddress: entry.LinkLayerAddress, + State: entry.State, + AddressFamily: entry.AddressFamily, + InterfaceIndex: entry.InterfaceIndex, + InterfaceAlias: entry.InterfaceAlias, + }) + } + + return neighbors, nil +} + func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { args := []string{"add", prefix.String()} diff --git a/client/internal/wgproxy/factory.go b/client/internal/wgproxy/factory.go index a6d170519..f4eb150b0 100644 --- a/client/internal/wgproxy/factory.go +++ b/client/internal/wgproxy/factory.go @@ -1,15 +1,17 @@ package wgproxy +import "context" + type Factory struct { wgPort int ebpfProxy Proxy } -func (w *Factory) GetProxy() Proxy { +func (w *Factory) GetProxy(ctx context.Context) Proxy { if w.ebpfProxy != nil { return w.ebpfProxy } - return NewWGUserSpaceProxy(w.wgPort) + return NewWGUserSpaceProxy(ctx, w.wgPort) } func (w *Factory) Free() error { diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go index e8d48b35b..0262994d7 100644 --- a/client/internal/wgproxy/factory_linux.go +++ b/client/internal/wgproxy/factory_linux.go @@ -3,14 +3,16 @@ package wgproxy import ( + "context" + log "github.com/sirupsen/logrus" ) -func NewFactory(wgPort int) *Factory { +func NewFactory(ctx context.Context, wgPort int) *Factory { f := &Factory{wgPort: wgPort} - ebpfProxy := NewWGEBPFProxy(wgPort) - err := ebpfProxy.Listen() + ebpfProxy := NewWGEBPFProxy(ctx, wgPort) + err := ebpfProxy.listen() if err != nil { log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) return f diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go index c538efd84..33a235c4a 100644 --- a/client/internal/wgproxy/factory_nonlinux.go +++ b/client/internal/wgproxy/factory_nonlinux.go @@ -2,6 +2,8 @@ package wgproxy -func NewFactory(wgPort int) *Factory { +import "context" + +func NewFactory(ctx context.Context, wgPort int) *Factory { return &Factory{wgPort: wgPort} } diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 16ebf0f35..b88df73a0 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -6,7 +6,7 @@ import ( // Proxy is a transfer layer between the Turn connection and the WireGuard type Proxy interface { - AddTurnConn(urnConn net.Conn) (net.Addr, error) + AddTurnConn(turnConn net.Conn) (net.Addr, error) CloseConn() error Free() error } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index 22d327376..01e8766e8 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -3,6 +3,7 @@ package wgproxy import ( + "context" "fmt" "io" "net" @@ -22,7 +23,11 @@ import ( // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { - ebpfManager ebpfMgr.Manager + ebpfManager ebpfMgr.Manager + + ctx context.Context + cancel context.CancelFunc + lastUsedPort uint16 localWGListenPort int @@ -34,7 +39,7 @@ type WGEBPFProxy struct { } // NewWGEBPFProxy create new WGEBPFProxy instance -func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { +func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { log.Debugf("instantiate ebpf proxy") wgProxy := &WGEBPFProxy{ localWGListenPort: wgPort, @@ -42,11 +47,13 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { lastUsedPort: 0, turnConnStore: make(map[uint16]net.Conn), } + wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx) + return wgProxy } -// Listen load ebpf program and listen the proxy -func (p *WGEBPFProxy) Listen() error { +// listen load ebpf program and listen the proxy +func (p *WGEBPFProxy) listen() error { pl := portLookup{} wgPorxyPort, err := pl.searchFreePort() if err != nil { @@ -72,7 +79,7 @@ func (p *WGEBPFProxy) Listen() error { if err != nil { cErr := p.Free() if cErr != nil { - log.Errorf("failed to close the wgproxy: %s", cErr) + log.Errorf("Failed to close the wgproxy: %s", cErr) } return err } @@ -102,6 +109,7 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { // CloseConn doing nothing because this type of proxy implementation does not store the connection func (p *WGEBPFProxy) CloseConn() error { + p.cancel() return nil } @@ -131,19 +139,27 @@ func (p *WGEBPFProxy) Free() error { func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { buf := make([]byte, 1500) + var err error + defer func() { + p.removeTurnConn(endpointPort) + }() for { - n, err := remoteConn.Read(buf) - if err != nil { - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - p.removeTurnConn(endpointPort) - log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err) + select { + case <-p.ctx.Done(): return - } - err = p.sendPkg(buf[:n], endpointPort) - if err != nil { - log.Errorf("failed to write out turn pkg to local conn: %v", err) + default: + var n int + n, err = remoteConn.Read(buf) + if err != nil { + if err != io.EOF { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) + } + return + } + err = p.sendPkg(buf[:n], endpointPort) + if err != nil { + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } } } } @@ -152,23 +168,28 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { func (p *WGEBPFProxy) proxyToRemote() { buf := make([]byte, 1500) for { - n, addr, err := p.conn.ReadFromUDP(buf) - if err != nil { - log.Errorf("failed to read UDP pkg from WG: %s", err) + select { + case <-p.ctx.Done(): return - } + default: + n, addr, err := p.conn.ReadFromUDP(buf) + if err != nil { + log.Errorf("failed to read UDP pkg from WG: %s", err) + return + } - p.turnConnMutex.Lock() - conn, ok := p.turnConnStore[uint16(addr.Port)] - p.turnConnMutex.Unlock() - if !ok { - log.Infof("turn conn not found by port: %d", addr.Port) - continue - } + p.turnConnMutex.Lock() + conn, ok := p.turnConnStore[uint16(addr.Port)] + p.turnConnMutex.Unlock() + if !ok { + log.Infof("turn conn not found by port: %d", addr.Port) + continue + } - _, err = conn.Write(buf[:n]) - if err != nil { - log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err) + _, err = conn.Write(buf[:n]) + if err != nil { + log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err) + } } } } @@ -266,15 +287,17 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { err := udpH.SetNetworkLayerForChecksum(ipH) if err != nil { - return err + return fmt.Errorf("set network layer for checksum: %w", err) } layerBuffer := gopacket.NewSerializeBuffer() err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload) if err != nil { - return err + return fmt.Errorf("serialize layers: %w", err) } - _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}) - return err + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { + return fmt.Errorf("write to raw conn: %w", err) + } + return nil } diff --git a/client/internal/wgproxy/proxy_ebpf_test.go b/client/internal/wgproxy/proxy_ebpf_test.go index 84c74cdcc..821e64218 100644 --- a/client/internal/wgproxy/proxy_ebpf_test.go +++ b/client/internal/wgproxy/proxy_ebpf_test.go @@ -3,11 +3,12 @@ package wgproxy import ( + "context" "testing" ) func TestWGEBPFProxy_connStore(t *testing.T) { - wgProxy := NewWGEBPFProxy(1) + wgProxy := NewWGEBPFProxy(context.Background(), 1) p, _ := wgProxy.storeTurnConn(nil) if p != 1 { @@ -27,7 +28,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) { } func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { - wgProxy := NewWGEBPFProxy(1) + wgProxy := NewWGEBPFProxy(context.Background(), 1) _, _ = wgProxy.storeTurnConn(nil) wgProxy.lastUsedPort = 65535 @@ -43,7 +44,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { } func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { - wgProxy := NewWGEBPFProxy(1) + wgProxy := NewWGEBPFProxy(context.Background(), 1) for i := 0; i < 65535; i++ { _, _ = wgProxy.storeTurnConn(nil) diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index 17ebfbc49..234ea2a42 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -21,21 +21,21 @@ type WGUserSpaceProxy struct { } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy -func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { - log.Debugf("instantiate new userspace proxy") +func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy { + log.Debugf("Initializing new user space proxy with port %d", wgPort) p := &WGUserSpaceProxy{ localWGListenPort: wgPort, } - p.ctx, p.cancel = context.WithCancel(context.Background()) + p.ctx, p.cancel = context.WithCancel(ctx) return p } // AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { - p.remoteConn = remoteConn +func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { + p.remoteConn = turnConn var err error - p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort)) + p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 3ec3ff10f..ab3fb8dd8 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -120,6 +120,7 @@ type LoginRequest struct { ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` + NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"` } func (x *LoginRequest) Reset() { @@ -274,6 +275,13 @@ func (x *LoginRequest) GetExtraIFaceBlacklist() []string { return nil } +func (x *LoginRequest) GetNetworkMonitor() bool { + if x != nil && x.NetworkMonitor != nil { + return *x.NetworkMonitor + } + return false +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1893,7 +1901,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8f, 0x07, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xcf, 0x07, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -1941,16 +1949,20 @@ var file_daemon_proto_rawDesc = []byte{ 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09, 0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63, - 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, - 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, - 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, - 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, - 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, - 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, - 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, - 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, + 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x12, 0x2b, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x48, 0x07, 0x52, + 0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x88, + 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69, + 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f, + 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, + 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, + 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, + 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x22, 0xb5, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index a647f267b..e90c7b063 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -87,6 +87,8 @@ message LoginRequest { optional bool rosenpassPermissive = 16; repeated string extraIFaceBlacklist = 17; + + optional bool networkMonitor = 18; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index e0e9504fa..db303e99e 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -358,6 +358,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.WireguardPort = &port } + if msg.NetworkMonitor != nil { + inputConfig.NetworkMonitor = msg.NetworkMonitor + s.latestConfigInput.NetworkMonitor = msg.NetworkMonitor + } + if len(msg.ExtraIFaceBlacklist) > 0 { inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist diff --git a/management/client/client.go b/management/client/client.go index 166fd02b1..928092a40 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -1,16 +1,18 @@ package client import ( + "context" "io" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/management/proto" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type Client interface { io.Closer - Sync(msgHandler func(msg *proto.SyncResponse) error) error + Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKey() (*wgtypes.Key, error) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) diff --git a/management/client/client_test.go b/management/client/client_test.go index 30f91c73b..5d04fa591 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" mgmt "github.com/netbirdio/netbird/management/server" @@ -255,7 +256,7 @@ func TestClient_Sync(t *testing.T) { ch := make(chan *mgmtProto.SyncResponse, 1) go func() { - err = client.Sync(func(msg *mgmtProto.SyncResponse) error { + err = client.Sync(context.Background(), func(msg *mgmtProto.SyncResponse) error { ch <- msg return nil }) diff --git a/management/client/grpc.go b/management/client/grpc.go index c6d1d5753..df687a160 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -113,8 +113,8 @@ func (c *GrpcClient) ready() bool { // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Blocking request. The result will be sent via msgHandler callback function -func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error { - backOff := defaultBackoff(c.ctx) +func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error { + backOff := defaultBackoff(ctx) operation := func() error { log.Debugf("management connection state %v", c.conn.GetState()) @@ -123,7 +123,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error if connState == connectivity.Shutdown { return backoff.Permanent(fmt.Errorf("connection to management has been shut down")) } else if !(connState == connectivity.Ready || connState == connectivity.Idle) { - c.conn.WaitForStateChange(c.ctx, connState) + c.conn.WaitForStateChange(ctx, connState) return fmt.Errorf("connection to management is not ready and in %s state", connState) } @@ -133,7 +133,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error return err } - ctx, cancelStream := context.WithCancel(c.ctx) + ctx, cancelStream := context.WithCancel(ctx) defer cancelStream() stream, err := c.connectToStream(ctx, *serverPubKey) if err != nil { diff --git a/management/client/mock.go b/management/client/mock.go index 042f837b8..02a5ade3a 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -1,6 +1,8 @@ package client import ( + "context" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/system" @@ -9,7 +11,7 @@ import ( type MockClient struct { CloseFunc func() error - SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error + SyncFunc func(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error GetServerPublicKeyFunc func() (*wgtypes.Key, error) RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) @@ -28,11 +30,11 @@ func (m *MockClient) Close() error { return m.CloseFunc() } -func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error { +func (m *MockClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error { if m.SyncFunc == nil { return nil } - return m.SyncFunc(msgHandler) + return m.SyncFunc(ctx, msgHandler) } func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { diff --git a/signal/client/client.go b/signal/client/client.go index dc73b2ce5..e4d9d74b3 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "io" "strings" @@ -33,7 +34,7 @@ type Client interface { io.Closer StreamConnected() bool GetStatus() Status - Receive(msgHandler func(msg *proto.Message) error) error + Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error Ready() bool IsHealthy() bool WaitStreamConnected() diff --git a/signal/client/client_test.go b/signal/client/client_test.go index c168783e7..8dea535f2 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -55,7 +55,7 @@ var _ = Describe("GrpcClient", func() { keyA, _ := wgtypes.GenerateKey() clientA := createSignalClient(addr, keyA) go func() { - err := clientA.Receive(func(msg *sigProto.Message) error { + err := clientA.Receive(context.Background(), func(msg *sigProto.Message) error { payloadReceivedOnA = msg.GetBody().GetPayload() featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported() msgReceived.Done() @@ -72,7 +72,7 @@ var _ = Describe("GrpcClient", func() { clientB := createSignalClient(addr, keyB) go func() { - err := clientB.Receive(func(msg *sigProto.Message) error { + err := clientB.Receive(context.Background(), func(msg *sigProto.Message) error { payloadReceivedOnB = msg.GetBody().GetPayload() featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported() err := clientB.Send(&sigProto.Message{ @@ -122,7 +122,7 @@ var _ = Describe("GrpcClient", func() { key, _ := wgtypes.GenerateKey() client := createSignalClient(addr, key) go func() { - err := client.Receive(func(msg *sigProto.Message) error { + err := client.Receive(context.Background(), func(msg *sigProto.Message) error { return nil }) if err != nil { diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7c4535e28..c6f03ec86 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -126,9 +126,9 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { // The messages will be handled by msgHandler function provided. // This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) // The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller. -func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error { +func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error { - var backOff = defaultBackoff(c.ctx) + var backOff = defaultBackoff(ctx) operation := func() error { @@ -139,13 +139,13 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error { if connState == connectivity.Shutdown { return backoff.Permanent(fmt.Errorf("connection to signal has been shut down")) } else if !(connState == connectivity.Ready || connState == connectivity.Idle) { - c.signalConn.WaitForStateChange(c.ctx, connState) + c.signalConn.WaitForStateChange(ctx, connState) return fmt.Errorf("connection to signal is not ready and in %s state", connState) } // connect to Signal stream identifying ourselves with a public WireGuard key // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) - ctx, cancelStream := context.WithCancel(c.ctx) + ctx, cancelStream := context.WithCancel(ctx) defer cancelStream() stream, err := c.connect(ctx, c.key.PublicKey().String()) if err != nil { diff --git a/signal/client/mock.go b/signal/client/mock.go index a0ce13aed..70ecea9ed 100644 --- a/signal/client/mock.go +++ b/signal/client/mock.go @@ -1,6 +1,8 @@ package client import ( + "context" + "github.com/netbirdio/netbird/signal/proto" ) @@ -10,7 +12,7 @@ type MockClient struct { StreamConnectedFunc func() bool ReadyFunc func() bool WaitStreamConnectedFunc func() - ReceiveFunc func(msgHandler func(msg *proto.Message) error) error + ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error SendToStreamFunc func(msg *proto.EncryptedMessage) error SendFunc func(msg *proto.Message) error } @@ -54,11 +56,11 @@ func (sm *MockClient) WaitStreamConnected() { sm.WaitStreamConnectedFunc() } -func (sm *MockClient) Receive(msgHandler func(msg *proto.Message) error) error { +func (sm *MockClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error { if sm.ReceiveFunc == nil { return nil } - return sm.ReceiveFunc(msgHandler) + return sm.ReceiveFunc(ctx, msgHandler) } func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {