diff --git a/client/cmd/down.go b/client/cmd/down.go index 1837b13da..4d9f1eba4 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -26,7 +26,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/internal/engine.go b/client/internal/engine.go index 21a765a96..9e275c007 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -266,8 +266,23 @@ func (e *Engine) Stop() error { e.close() e.wgConnWorker.Wait() - log.Infof("stopped Netbird Engine") - return nil + + maxWaitTime := 5 * time.Second + timeout := time.After(maxWaitTime) + + for { + if !e.IsWGIfaceUp() { + log.Infof("stopped Netbird Engine") + return nil + } + + select { + case <-timeout: + return fmt.Errorf("timeout when waiting for interface shutdown") + default: + time.Sleep(100 * time.Millisecond) + } + } } // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services @@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { return slices.Equal(checks.Files, oChecks.Files) }) } + +func (e *Engine) IsWGIfaceUp() bool { + if e == nil || e.wgInterface == nil { + return false + } + iface, err := net.InterfaceByName(e.wgInterface.Name()) + if err != nil { + log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err) + return false + } + + if iface.Flags&net.FlagUp != 0 { + return true + } + + return false +} diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index 8d6ccd51b..29df7ea7f 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -4,6 +4,7 @@ package networkmonitor import ( "context" + "errors" "fmt" "syscall" "unsafe" @@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca return fmt.Errorf("failed to open routing socket: %v", err) } defer func() { - if err := unix.Close(fd); err != nil { + err := unix.Close(fd) + if err != nil && !errors.Is(err, unix.EBADF) { log.Errorf("Network monitor: failed to close routing socket: %v", err) } }() + go func() { + <-ctx.Done() + err := unix.Close(fd) + if err != nil && !errors.Is(err, unix.EBADF) { + log.Debugf("Network monitor: closed routing socket") + } + }() + for { select { case <-ctx.Done(): @@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca buf := make([]byte, 2048) n, err := unix.Read(fd, buf) if err != nil { - log.Errorf("Network monitor: failed to read from routing socket: %v", err) + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Errorf("Network monitor: failed to read from routing socket: %v", err) + } continue } if n < unix.SizeofRtMsghdr { diff --git a/client/server/server.go b/client/server/server.go index 2805c10f4..8173d0741 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } // Down engine work in the daemon. -func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { +func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) - return &proto.DownResponse{}, nil + maxWaitTime := 5 * time.Second + timeout := time.After(maxWaitTime) + + engine := s.connectClient.Engine() + + for { + if !engine.IsWGIfaceUp() { + return &proto.DownResponse{}, nil + } + + select { + case <-ctx.Done(): + return &proto.DownResponse{}, nil + case <-timeout: + return nil, fmt.Errorf("failed to shut down properly") + default: + time.Sleep(100 * time.Millisecond) + } + } } // Status returns the daemon status diff --git a/management/client/grpc.go b/management/client/grpc.go index a8f4a91c7..568c15313 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "io" "sync" @@ -11,15 +10,11 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" - - "github.com/cenkalti/backoff/v4" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" @@ -51,26 +46,21 @@ type GrpcClient struct { // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + var conn *grpc.ClientConn - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + operation := func() error { + var err error + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + if err != nil { + log.Printf("createConnection error: %v", err) + return err + } + return nil } - mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) - defer cancel() - conn, err := grpc.DialContext( - mgmCtx, - addr, - transportOption, - nbgrpc.WithCustomDialer(), - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - })) + err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { - log.Errorf("failed creating connection to Management Service %v", err) + log.Errorf("failed creating connection to Management Service: %v", err) return nil, err } @@ -326,25 +316,41 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro if !c.ready() { return nil, fmt.Errorf(errMsgNoMgmtConnection) } + loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err } - mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout) - defer cancel() - resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ - WgPubKey: c.key.PublicKey().String(), - Body: loginReq, - }) + + var resp *proto.EncryptedMessage + operation := func() error { + mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) + defer cancel() + + var err error + resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: loginReq, + }) + if err != nil { + log.Printf("Login error: %v", err) + return err + } + + return nil + } + + err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx)) if err != nil { + log.Errorf("failed to login to Management Service: %v", err) return nil, err } loginResp := &proto.LoginResponse{} err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) if err != nil { - log.Errorf("failed to decrypt registration message: %s", err) + log.Errorf("failed to decrypt login response: %s", err) return nil, err } diff --git a/signal/client/grpc.go b/signal/client/grpc.go index c6f03ec86..7a3b502ff 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "fmt" "io" "sync" @@ -14,9 +13,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error { // NewClient creates a new Signal client func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { + var conn *grpc.ClientConn - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + operation := func() error { + var err error + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + if err != nil { + log.Printf("createConnection error: %v", err) + return err + } + return nil } - sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout) - defer cancel() - conn, err := grpc.DialContext( - sigCtx, - addr, - transportOption, - nbgrpc.WithCustomDialer(), - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 30 * time.Second, - Timeout: 10 * time.Second, - })) - + err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { - log.Errorf("failed to connect to the signalling server %v", err) + log.Errorf("failed to connect to the signalling server: %v", err) return nil, err } @@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, if err != nil { log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) - //todo send something?? + // todo send something?? } } } diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 3fba0c84e..57ab8fd55 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -2,12 +2,18 @@ package grpc import ( "context" + "crypto/tls" "net" "os/user" "runtime" + "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption { return conn, nil }) } + +// grpcDialBackoff is the backoff mechanism for the grpc calls +func Backoff(ctx context.Context) backoff.BackOff { + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = 10 * time.Second + b.Clock = backoff.SystemClock + return backoff.WithContext(b, ctx) +} + +func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { + transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + + if tlsEnabled { + transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + } + + connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + connCtx, + addr, + transportOption, + WithCustomDialer(), + grpc.WithBlock(), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + }), + ) + if err != nil { + log.Printf("DialContext error: %v", err) + return nil, err + } + + return conn, nil +}