[client] Destory WG interface on down timeout (#2435)

wait on engine down to not only wait for the interface to be down but completely removed. If the waiting loop reaches the timeout we will trigger an interface destroy. On the up command, it now waits until the engine is fully running before sending the response to the CLI. Includes a small refactor of probes to comply with sonar rules about parameter count in the function call
This commit is contained in:
pascal-fischer 2024-09-02 19:19:14 +02:00 committed by GitHub
parent 95174d4619
commit 13e7198046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 222 additions and 116 deletions

View File

@ -42,6 +42,8 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err)
return err
}
cmd.Println("Disconnected")
return nil
},
}

View File

@ -55,17 +55,15 @@ func NewConnectClient(
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil, nil, nil)
return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
probes *ProbeHolder,
runningWg *sync.WaitGroup,
) error {
return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
return c.run(MobileDependency{}, probes, runningWg)
}
// RunOnAndroid with main logic on mobile system
@ -84,7 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, nil, nil, nil, nil)
return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) RunOniOS(
@ -100,15 +98,13 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return c.run(mobileDependency, nil, nil, nil, nil)
return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) run(
mobileDependency MobileDependency,
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
probes *ProbeHolder,
runningWg *sync.WaitGroup,
) error {
defer func() {
if r := recover(); r != nil {
@ -255,7 +251,7 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
err = c.engine.Start()
@ -267,17 +263,15 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningWg != nil {
runningWg.Done()
}
<-engineCtx.Done()
c.statusRecorder.ClientTeardown()
backOff.Reset()
err = c.engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
return wrapErr(err)
}
log.Info("stopped NetBird client")
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
@ -307,6 +301,12 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
func (c *ConnectClient) Stop() error {
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
return c.engine.Stop()
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false

View File

@ -155,10 +155,7 @@ type Engine struct {
dnsServer dns.Server
mgmProbe *Probe
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
probes *ProbeHolder
wgConnWorker sync.WaitGroup
@ -192,9 +189,6 @@ func NewEngine(
mobileDep,
statusRecorder,
nil,
nil,
nil,
nil,
checks,
)
}
@ -208,10 +202,7 @@ func NewEngineWithProbes(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
mgmProbe *Probe,
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
@ -229,22 +220,20 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
probes: probes,
checks: checks,
}
}
func (e *Engine) Stop() error {
if e == nil {
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
log.Debugf("tried stopping engine that is nil")
return nil
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.cancel != nil {
e.cancel()
}
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
@ -260,29 +249,21 @@ func (e *Engine) Stop() error {
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
if e.cancel != nil {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
e.wgConnWorker.Wait()
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
log.Infof("Engine stopped")
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@ -1415,24 +1396,27 @@ func (e *Engine) getRosenpassAddr() string {
}
func (e *Engine) receiveProbeEvents() {
if e.signalProbe != nil {
go e.signalProbe.Receive(e.ctx, func() bool {
if e.probes == nil {
return
}
if e.probes.SignalProbe != nil {
go e.probes.SignalProbe.Receive(e.ctx, func() bool {
healthy := e.signal.IsHealthy()
log.Debugf("received signal probe request, healthy: %t", healthy)
return healthy
})
}
if e.mgmProbe != nil {
go e.mgmProbe.Receive(e.ctx, func() bool {
if e.probes.MgmProbe != nil {
go e.probes.MgmProbe.Receive(e.ctx, func() bool {
healthy := e.mgmClient.IsHealthy()
log.Debugf("received management probe request, healthy: %t", healthy)
return healthy
})
}
if e.relayProbe != nil {
go e.relayProbe.Receive(e.ctx, func() bool {
if e.probes.RelayProbe != nil {
go e.probes.RelayProbe.Receive(e.ctx, func() bool {
healthy := true
results := append(e.probeSTUNs(), e.probeTURNs()...)
@ -1451,8 +1435,8 @@ func (e *Engine) receiveProbeEvents() {
})
}
if e.wgProbe != nil {
go e.wgProbe.Receive(e.ctx, func() bool {
if e.probes.WgProbe != nil {
go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request")
for _, peer := range e.peerConns {
@ -1548,20 +1532,3 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@ -2,6 +2,13 @@ package internal
import "context"
type ProbeHolder struct {
MgmProbe *Probe
SignalProbe *Probe
RelayProbe *Probe
WgProbe *Probe
}
// Probe allows to run on-demand callbacks from different code locations.
// Pass the probe to a receiving and a sending end. The receiving end starts listening
// to requests with Receive and executes a callback when the sending end requests it

View File

@ -12,7 +12,6 @@ import (
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
@ -143,10 +142,14 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
runningWg := sync.WaitGroup{}
runningWg.Add(1)
if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, &runningWg)
}
runningWg.Wait()
return nil
}
@ -154,7 +157,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
runningWg *sync.WaitGroup,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@ -185,7 +188,15 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe)
probes := internal.ProbeHolder{
MgmProbe: s.mgmProbe,
SignalProbe: s.signalProbe,
RelayProbe: s.relayProbe,
WgProbe: s.wgProbe,
}
err := s.connectClient.RunWithProbes(&probes, runningWg)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@ -576,7 +587,11 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
runningWg := sync.WaitGroup{}
runningWg.Add(1)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, &runningWg)
runningWg.Wait()
return &proto.UpResponse{}, nil
}
@ -590,28 +605,19 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
return nil, fmt.Errorf("service is not up")
}
s.actCancel()
err := s.connectClient.Stop()
if err != nil {
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
log.Infof("service is down")
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
return &proto.DownResponse{}, nil
}
// Status returns the daemon status

View File

@ -73,7 +73,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}

View File

@ -124,7 +124,23 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.Close()
err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
}
err = w.waitUntilRemoved()
if err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
}
log.Infof("interface %s successfully removed", w.Name())
}
return nil
}
// SetFilter sets packet filters for the userspace implementation
@ -163,3 +179,30 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)
defer timeout.Stop()
for {
iface, err := net.InterfaceByName(w.Name())
if err != nil {
if _, ok := err.(*net.OpError); ok {
log.Infof("interface %s has been removed", w.Name())
return nil
}
log.Debugf("failed to get interface by name %s: %v", w.Name(), err)
} else if iface == nil {
log.Infof("interface %s has been removed", w.Name())
return nil
}
select {
case <-timeout.C:
return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name())
default:
time.Sleep(100 * time.Millisecond)
}
}
}

View File

@ -0,0 +1,17 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package iface
import (
"fmt"
"os/exec"
)
func (w *WGIface) Destroy() error {
out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}

View File

@ -0,0 +1,22 @@
//go:build linux && !android
package iface
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *WGIface) Destroy() error {
link, err := netlink.LinkByName(w.Name())
if err != nil {
return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
}
if err := netlink.LinkDel(link); err != nil {
return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
}
return nil
}

View File

@ -0,0 +1,9 @@
//go:build android || (ios && !darwin)
package iface
import "errors"
func (w *WGIface) Destroy() error {
return errors.New("not supported on mobile")
}

View File

@ -0,0 +1,32 @@
//go:build windows
package iface
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (w *WGIface) Destroy() error {
netshCmd := GetSystem32Command("netsh")
out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@ -3,6 +3,7 @@
package iface
import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
@ -41,7 +42,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
func (t *tunDevice) Create() (wgConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunDevice)
@ -55,7 +56,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@ -63,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}

View File

@ -70,7 +70,7 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
configurer := newWGConfigurer(t.name)
if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil

View File

@ -47,7 +47,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
@ -61,7 +61,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
err = t.configurer.configureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)

View File

@ -48,8 +48,8 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun unterface (%s, %d): %s", t.name, t.mtu, err)
return nil, err
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.wrapper = newDeviceWrapper(tunIface)
@ -63,7 +63,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@ -71,7 +71,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}

View File

@ -59,7 +59,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
t.wrapper = newDeviceWrapper(tunDevice)
@ -89,7 +89,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, err
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = newWGUSPConfigurer(t.device, t.name)
@ -97,7 +97,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
if err != nil {
t.device.Close()
t.configurer.close()
return nil, err
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}