Wait on daemon down (#2279)

This commit is contained in:
pascal-fischer 2024-07-17 16:26:06 +02:00 committed by GitHub
parent 4fad0e521f
commit 95d725f2c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 159 additions and 59 deletions

View File

@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@ -266,8 +266,23 @@ func (e *Engine) Stop() error {
e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@ -4,6 +4,7 @@ package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for {
select {
case <-ctx.Done():
@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {

View File

@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
// Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
}
// Status returns the daemon status

View File

@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
@ -51,26 +46,21 @@ type GrpcClient struct {
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
var conn *grpc.ClientConn
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
}
@ -326,25 +316,41 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
log.Printf("Login error: %v", err)
return err
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
}

View File

@ -2,7 +2,6 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
@ -14,9 +13,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error {
// NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed to connect to the signalling server %v", err)
log.Errorf("failed to connect to the signalling server: %v", err)
return nil, err
}
@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something??
// todo send something??
}
}
}

View File

@ -2,12 +2,18 @@ package grpc
import (
"context"
"crypto/tls"
"net"
"os/user"
"runtime"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
nbnet "github.com/netbirdio/netbird/util/net"
)
@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption {
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
b.Clock = backoff.SystemClock
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, err
}
return conn, nil
}