Wait on daemon down (#2279)

This commit is contained in:
pascal-fischer 2024-07-17 16:26:06 +02:00 committed by GitHub
parent 4fad0e521f
commit 95d725f2c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 159 additions and 59 deletions

View File

@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@ -266,8 +266,23 @@ func (e *Engine) Stop() error {
e.close() e.close()
e.wgConnWorker.Wait() e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
} }
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files) return slices.Equal(checks.Files, oChecks.Files)
}) })
} }
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@ -4,6 +4,7 @@ package networkmonitor
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"syscall" "syscall"
"unsafe" "unsafe"
@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
} }
defer func() { defer func() {
if err := unix.Close(fd); err != nil { err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err) log.Errorf("Network monitor: failed to close routing socket: %v", err)
} }
}() }()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
if err != nil { if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err) if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue continue
} }
if n < unix.SizeofRtMsghdr { if n < unix.SizeofRtMsghdr {

View File

@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
// Down engine work in the daemon. // Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
} }
// Status returns the daemon status // Status returns the daemon status

View File

@ -2,7 +2,6 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@ -51,26 +46,21 @@ type GrpcClient struct {
// NewClient creates a new client to Management service // NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) var conn *grpc.ClientConn
if tlsEnabled { operation := func() error {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
} }
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil { if err != nil {
log.Errorf("failed creating connection to Management Service %v", err) log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err return nil, err
} }
@ -326,25 +316,41 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection) return nil, fmt.Errorf(errMsgNoMgmtConnection)
} }
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil { if err != nil {
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)
return nil, err return nil, err
} }
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel() var resp *proto.EncryptedMessage
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ operation := func() error {
WgPubKey: c.key.PublicKey().String(), mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
Body: loginReq, defer cancel()
})
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
log.Printf("Login error: %v", err)
return err
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil { if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err return nil, err
} }
loginResp := &proto.LoginResponse{} loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil { if err != nil {
log.Errorf("failed to decrypt registration message: %s", err) log.Errorf("failed to decrypt login response: %s", err)
return nil, err return nil, err
} }

View File

@ -2,7 +2,6 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -14,9 +13,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error {
// NewClient creates a new Signal client // NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) operation := func() error {
var err error
if tlsEnabled { conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
} }
sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout) err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil { if err != nil {
log.Errorf("failed to connect to the signalling server %v", err) log.Errorf("failed to connect to the signalling server: %v", err)
return nil, err return nil, err
} }
@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
if err != nil { if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something?? // todo send something??
} }
} }
} }

View File

@ -2,12 +2,18 @@ package grpc
import ( import (
"context" "context"
"crypto/tls"
"net" "net"
"os/user" "os/user"
"runtime" "runtime"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption {
return conn, nil return conn, nil
}) })
} }
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
b.Clock = backoff.SystemClock
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, err
}
return conn, nil
}