1
0
mirror of https://github.com/netbirdio/netbird.git synced 2025-07-11 11:55:48 +02:00

[client] Destroy WG interface on down timeout ()

Wait on engine down not only until the interface is down but until it is completely removed. If the waiting loop reaches the timeout, we trigger an interface destroy. The up command now waits until the engine is fully running before sending the response to the CLI. Includes a small refactor of probes to comply with Sonar rules about the parameter count in function calls.
This commit is contained in:
pascal-fischer
2024-09-02 19:19:14 +02:00
committed by GitHub
parent 95174d4619
commit 13e7198046
16 changed files with 222 additions and 116 deletions

@ -42,6 +42,8 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err)
return err
}
cmd.Println("Disconnected")
return nil
},
}

@ -55,17 +55,15 @@ func NewConnectClient(
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil, nil, nil)
return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
probes *ProbeHolder,
runningWg *sync.WaitGroup,
) error {
return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
return c.run(MobileDependency{}, probes, runningWg)
}
// RunOnAndroid with main logic on mobile system
@ -84,7 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, nil, nil, nil, nil)
return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) RunOniOS(
@ -100,15 +98,13 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return c.run(mobileDependency, nil, nil, nil, nil)
return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) run(
mobileDependency MobileDependency,
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
probes *ProbeHolder,
runningWg *sync.WaitGroup,
) error {
defer func() {
if r := recover(); r != nil {
@ -255,7 +251,7 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
err = c.engine.Start()
@ -267,17 +263,15 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningWg != nil {
runningWg.Done()
}
<-engineCtx.Done()
c.statusRecorder.ClientTeardown()
backOff.Reset()
err = c.engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
return wrapErr(err)
}
log.Info("stopped NetBird client")
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
@ -307,6 +301,12 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
// Stop stops the engine currently held by the client and returns the error
// reported by Engine.Stop, if any.
//
// The engine mutex is held for the whole call, so the engine field cannot be
// replaced while the stop is in progress.
func (c *ConnectClient) Stop() error {
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
// c.engine may still be nil if no engine has been created yet; Engine.Stop
// checks for a nil receiver itself.
return c.engine.Stop()
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false

@ -155,10 +155,7 @@ type Engine struct {
dnsServer dns.Server
mgmProbe *Probe
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
probes *ProbeHolder
wgConnWorker sync.WaitGroup
@ -192,9 +189,6 @@ func NewEngine(
mobileDep,
statusRecorder,
nil,
nil,
nil,
nil,
checks,
)
}
@ -208,10 +202,7 @@ func NewEngineWithProbes(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
@ -229,22 +220,20 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
probes: probes,
checks: checks,
}
}
func (e *Engine) Stop() error {
if e == nil {
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
log.Debugf("tried stopping engine that is nil")
return nil
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.cancel != nil {
e.cancel()
}
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
@ -260,29 +249,21 @@ func (e *Engine) Stop() error {
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
if e.cancel != nil {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
e.wgConnWorker.Wait()
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
log.Infof("Engine stopped")
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@ -1415,24 +1396,27 @@ func (e *Engine) getRosenpassAddr() string {
}
func (e *Engine) receiveProbeEvents() {
if e.signalProbe != nil {
go e.signalProbe.Receive(e.ctx, func() bool {
if e.probes == nil {
return
}
if e.probes.SignalProbe != nil {
go e.probes.SignalProbe.Receive(e.ctx, func() bool {
healthy := e.signal.IsHealthy()
log.Debugf("received signal probe request, healthy: %t", healthy)
return healthy
})
}
if e.mgmProbe != nil {
go e.mgmProbe.Receive(e.ctx, func() bool {
if e.probes.MgmProbe != nil {
go e.probes.MgmProbe.Receive(e.ctx, func() bool {
healthy := e.mgmClient.IsHealthy()
log.Debugf("received management probe request, healthy: %t", healthy)
return healthy
})
}
if e.relayProbe != nil {
go e.relayProbe.Receive(e.ctx, func() bool {
if e.probes.RelayProbe != nil {
go e.probes.RelayProbe.Receive(e.ctx, func() bool {
healthy := true
results := append(e.probeSTUNs(), e.probeTURNs()...)
@ -1451,8 +1435,8 @@ func (e *Engine) receiveProbeEvents() {
})
}
if e.wgProbe != nil {
go e.wgProbe.Receive(e.ctx, func() bool {
if e.probes.WgProbe != nil {
go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request")
for _, peer := range e.peerConns {
@ -1548,20 +1532,3 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
// IsWGIfaceUp reports whether the engine's WireGuard interface can be found on
// the host and carries the "up" flag.
//
// A nil engine, an engine without an interface handle, and a failed interface
// lookup are all reported as "not up".
func (e *Engine) IsWGIfaceUp() bool {
	if e == nil || e.wgInterface == nil {
		return false
	}

	name := e.wgInterface.Name()
	link, lookupErr := net.InterfaceByName(name)
	if lookupErr != nil {
		log.Debugf("failed to get interface by name %s: %v", name, lookupErr)
		return false
	}

	// The interface exists; it counts as up only when FlagUp is set.
	return link.Flags&net.FlagUp != 0
}

@ -2,6 +2,13 @@ package internal
import "context"
// ProbeHolder groups the four on-demand health probes so they can be handed
// to the client and engine as a single value instead of four separate
// parameters. Any field may be nil; the engine skips probes that are nil.
type ProbeHolder struct {
// MgmProbe is answered by the engine with the management client's health.
MgmProbe *Probe
// SignalProbe is answered by the engine with the signal client's health.
SignalProbe *Probe
// RelayProbe is answered by the engine based on its STUN and TURN probe results.
RelayProbe *Probe
// WgProbe is handled by the engine by walking its peer connections.
WgProbe *Probe
}
// Probe allows to run on-demand callbacks from different code locations.
// Pass the probe to a receiving and a sending end. The receiving end starts listening
// to requests with Receive and executes a callback when the sending end requests it

@ -12,7 +12,6 @@ import (
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
@ -143,10 +142,14 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
runningWg := sync.WaitGroup{}
runningWg.Add(1)
if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, &runningWg)
}
runningWg.Wait()
return nil
}
@ -154,7 +157,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
runningWg *sync.WaitGroup,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@ -185,7 +188,15 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe)
probes := internal.ProbeHolder{
MgmProbe: s.mgmProbe,
SignalProbe: s.signalProbe,
RelayProbe: s.relayProbe,
WgProbe: s.wgProbe,
}
err := s.connectClient.RunWithProbes(&probes, runningWg)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@ -576,7 +587,11 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
runningWg := sync.WaitGroup{}
runningWg.Add(1)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, &runningWg)
runningWg.Wait()
return &proto.UpResponse{}, nil
}
@ -590,28 +605,19 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
return nil, fmt.Errorf("service is not up")
}
s.actCancel()
err := s.connectClient.Stop()
if err != nil {
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
log.Infof("service is down")
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
return &proto.DownResponse{}, nil
}
// Status returns the daemon status

@ -73,7 +73,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}

@ -124,7 +124,23 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.Close()
err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
}
err = w.waitUntilRemoved()
if err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
}
log.Infof("interface %s successfully removed", w.Name())
}
return nil
}
// SetFilter sets packet filters for the userspace implementation
@ -163,3 +179,30 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
}
// waitUntilRemoved polls the host until the interface named w.Name() can no
// longer be looked up, giving up after five seconds.
//
// The first lookup happens immediately; later lookups run every 100ms.
// It returns nil as soon as the interface is considered removed and an error
// if it is still present when the timeout fires.
func (w *WGIface) waitUntilRemoved() error {
	const (
		maxWaitTime  = 5 * time.Second
		pollInterval = 100 * time.Millisecond
	)

	timeout := time.NewTimer(maxWaitTime)
	defer timeout.Stop()

	// A ticker lets the loop block on "next poll" and "timeout" together
	// instead of sleeping blindly and only then peeking at the timer.
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		iface, err := net.InterfaceByName(w.Name())

		// Any *net.OpError from the lookup is treated as "interface is gone",
		// as is a nil interface without an error.
		_, isOpErr := err.(*net.OpError)
		if isOpErr || (err == nil && iface == nil) {
			log.Infof("interface %s has been removed", w.Name())
			return nil
		}
		if err != nil {
			log.Debugf("failed to get interface by name %s: %v", w.Name(), err)
		}

		select {
		case <-timeout.C:
			return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name())
		case <-ticker.C:
			// Time for the next lookup.
		}
	}
}

@ -0,0 +1,17 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package iface
import (
"fmt"
"os/exec"
)
// Destroy removes the interface by running `ifconfig <name> destroy`.
//
// On failure the returned error wraps the command error and includes the
// command's combined stdout/stderr output.
func (w *WGIface) Destroy() error {
	output, cmdErr := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
	if cmdErr == nil {
		return nil
	}
	return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), cmdErr, output)
}

@ -0,0 +1,22 @@
//go:build linux && !android
package iface
import (
"fmt"
"github.com/vishvananda/netlink"
)
// Destroy deletes the interface through netlink: it looks the link up by the
// interface name and then deletes it.
//
// An error is returned if either the lookup or the deletion fails.
func (w *WGIface) Destroy() error {
	link, lookupErr := netlink.LinkByName(w.Name())
	if lookupErr != nil {
		return fmt.Errorf("failed to get link by name %s: %w", w.Name(), lookupErr)
	}

	delErr := netlink.LinkDel(link)
	if delErr != nil {
		return fmt.Errorf("failed to delete link %s: %w", w.Name(), delErr)
	}

	return nil
}

@ -0,0 +1,9 @@
//go:build android || (ios && !darwin)
package iface
import "errors"
// Destroy is not implemented for mobile builds; it always returns an error.
func (w *WGIface) Destroy() error {
return errors.New("not supported on mobile")
}

@ -0,0 +1,32 @@
//go:build windows
package iface
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
// Destroy shuts the interface down by invoking
// `netsh interface set interface <name> admin=disable`.
//
// On failure the returned error wraps the command error and includes the
// command's combined stdout/stderr output.
func (w *WGIface) Destroy() error {
	netsh := GetSystem32Command("netsh")
	args := []string{"interface", "set", "interface", w.Name(), "admin=disable"}

	if output, cmdErr := exec.Command(netsh, args...).CombinedOutput(); cmdErr != nil {
		return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), cmdErr, output)
	}
	return nil
}
// GetSystem32Command returns the string to use when executing the given
// system command. If the command resolves through PATH it is returned
// unchanged; otherwise the full path under C:\windows\system32 (with an .exe
// suffix) is returned instead.
func GetSystem32Command(command string) string {
	if _, lookErr := exec.LookPath(command); lookErr != nil {
		// Not resolvable via PATH: fall back to the assumed system32 location.
		log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
		const system32Dir = "C:\\windows\\system32\\"
		return system32Dir + command + ".exe"
	}
	return command
}

@ -3,6 +3,7 @@
package iface
import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
@ -41,7 +42,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
func (t *tunDevice) Create() (wgConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunDevice)
@ -55,7 +56,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@ -63,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}

@ -70,7 +70,7 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
configurer := newWGConfigurer(t.name)
if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil

@ -47,7 +47,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
@ -61,7 +61,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
err = t.configurer.configureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)

@ -48,8 +48,8 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun unterface (%s, %d): %s", t.name, t.mtu, err)
return nil, err
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
@ -63,7 +63,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@ -71,7 +71,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}

@ -59,7 +59,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
t.wrapper = newDeviceWrapper(tunDevice)
@ -89,7 +89,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@ -97,7 +97,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}