package client import ( "context" "crypto/tls" "fmt" "io" "sync" "time" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" "github.com/cenkalti/backoff/v4" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second // ConnStateNotifier is a wrapper interface of the status recorders type ConnStateNotifier interface { MarkManagementDisconnected(error) MarkManagementConnected() } type GrpcClient struct { key wgtypes.Key realClient proto.ManagementServiceClient ctx context.Context conn *grpc.ClientConn connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex } // NewClient creates a new client to Management service func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) } mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) defer cancel() conn, err := grpc.DialContext( mgmCtx, addr, transportOption, nbgrpc.WithCustomDialer(), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, })) if err != nil { log.Errorf("failed creating connection to Management Service %v", err) return nil, err } realClient := proto.NewManagementServiceClient(conn) return &GrpcClient{ key: ourPrivateKey, realClient: realClient, ctx: ctx, conn: conn, connStateCallbackLock: sync.RWMutex{}, }, nil } // Close closes connection to the Management Service func (c *GrpcClient) Close() error { return c.conn.Close() } // SetConnStateListener set the ConnStateNotifier func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) { c.connStateCallbackLock.Lock() defer c.connStateCallbackLock.Unlock() c.connStateCallback = notifier } // defaultBackoff is a basic backoff mechanism for general issues func defaultBackoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, RandomizationFactor: 1, Multiplier: 1.7, MaxInterval: 10 * time.Second, MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months Stop: backoff.Stop, Clock: backoff.SystemClock, }, ctx) } // ready indicates whether the client is okay and ready to be used // for now it just checks whether gRPC connection to the service is ready func (c *GrpcClient) ready() bool { return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle } // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Blocking request. The result will be sent via msgHandler callback function func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error { backOff := defaultBackoff(ctx) operation := func() error { log.Debugf("management connection state %v", c.conn.GetState()) connState := c.conn.GetState() if connState == connectivity.Shutdown { return backoff.Permanent(fmt.Errorf("connection to management has been shut down")) } else if !(connState == connectivity.Ready || connState == connectivity.Idle) { c.conn.WaitForStateChange(ctx, connState) return fmt.Errorf("connection to management is not ready and in %s state", connState) } serverPubKey, err := c.GetServerPublicKey() if err != nil { log.Debugf("failed getting Management Service public key: %s", err) return err } ctx, cancelStream := context.WithCancel(ctx) defer cancelStream() stream, err := c.connectToStream(ctx, *serverPubKey) if err != nil { log.Debugf("failed to open Management Service stream: %s", err) if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer } return err } log.Infof("connected to the Management Service stream") c.notifyConnected() // blocking until error err = c.receiveEvents(stream, *serverPubKey, msgHandler) if err != nil { s, _ := gstatus.FromError(err) switch s.Code() { case codes.PermissionDenied: return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer case codes.Canceled: log.Debugf("management connection context has been canceled, this usually indicates shutdown") return nil default: backOff.Reset() // reset backoff counter after successful connection c.notifyDisconnected(err) log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) return err } } return nil } err := backoff.Retry(operation, backOff) if err != nil { log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err) return err } return nil } // GetNetworkMap return with the network map func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) { serverPubKey, err := c.GetServerPublicKey() if err != nil { log.Debugf("failed getting Management Service public key: %s", err) return nil, err } ctx, cancelStream := context.WithCancel(c.ctx) defer cancelStream() stream, err := c.connectToStream(ctx, *serverPubKey) if err != nil { log.Debugf("failed to open Management Service stream: %s", err) return nil, err } defer func() { _ = stream.CloseSend() }() update, err := stream.Recv() if err == io.EOF { log.Debugf("Management stream has been closed by server: %s", err) return nil, err } if err != nil { log.Debugf("disconnected from Management Service sync stream: %v", err) return nil, err } decryptedResp := &proto.SyncResponse{} err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp) if err != nil { log.Errorf("failed decrypting update message from Management Service: %s", err) return nil, err } if decryptedResp.GetNetworkMap() == nil { return nil, fmt.Errorf("invalid msg, required network map") } return decryptedResp.GetNetworkMap(), nil } func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { req := &proto.SyncRequest{} myPrivateKey := c.key myPublicKey := myPrivateKey.PublicKey() encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req) if err != nil { log.Errorf("failed encrypting message: %s", err) return nil, err } syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq} sync, err := c.realClient.Sync(ctx, syncReq) if err != nil { return nil, err } return sync, nil } func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error { for { update, err := stream.Recv() if err == io.EOF { log.Debugf("Management stream has been closed by server: %s", err) return err } if err != nil { log.Debugf("disconnected from Management Service sync stream: %v", err) return err } log.Debugf("got an update message from Management Service") decryptedResp := &proto.SyncResponse{} err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp) if err != nil { log.Errorf("failed decrypting update message from Management Service: %s", err) return err } err = msgHandler(decryptedResp) if err != nil { log.Errorf("failed handling an update message received from Management Service: %v", err.Error()) return err } } } // GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management") } mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) defer cancel() resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) return nil, fmt.Errorf("failed while getting Management Service public key") } serverKey, err := wgtypes.ParseKey(resp.Key) if err != nil { return nil, err } return &serverKey, nil } // IsHealthy probes the gRPC connection and returns false on errors func (c *GrpcClient) IsHealthy() bool { switch c.conn.GetState() { case connectivity.TransientFailure: return false case connectivity.Connecting: return true case connectivity.Shutdown: return true case connectivity.Idle: case connectivity.Ready: } ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) defer cancel() _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) if err != nil { c.notifyDisconnected(err) log.Warnf("health check returned: %s", err) return false } c.notifyConnected() return true } func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management") } loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err } mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout) defer cancel() resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ WgPubKey: c.key.PublicKey().String(), Body: loginReq, }) if err != nil { return nil, err } loginResp := &proto.LoginResponse{} err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) if err != nil { log.Errorf("failed to decrypt registration message: %s", err) return nil, err } return loginResp, nil } // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Takes care of encrypting and decrypting messages. // This method will also collect system info and send it with the request (e.g. hostname, os, etc) func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys}) } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys}) } // GetDeviceAuthorizationFlow returns a device authorization flow information. // It also takes care of encrypting and decrypting messages. func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management in order to get device authorization flow") } mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) defer cancel() message := &proto.DeviceAuthorizationFlowRequest{} encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) if err != nil { return nil, err } resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{ WgPubKey: c.key.PublicKey().String(), Body: encryptedMSG}, ) if err != nil { return nil, err } flowInfoResp := &proto.DeviceAuthorizationFlow{} err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) if err != nil { errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err) log.Error(errWithMSG) return nil, errWithMSG } return flowInfoResp, nil } // GetPKCEAuthorizationFlow returns a pkce authorization flow information. // It also takes care of encrypting and decrypting messages. func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { if !c.ready() { return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow") } mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) defer cancel() message := &proto.PKCEAuthorizationFlowRequest{} encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) if err != nil { return nil, err } resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{ WgPubKey: c.key.PublicKey().String(), Body: encryptedMSG, }) if err != nil { return nil, err } flowInfoResp := &proto.PKCEAuthorizationFlow{} err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) if err != nil { errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err) log.Error(errWithMSG) return nil, errWithMSG } return flowInfoResp, nil } func (c *GrpcClient) notifyDisconnected(err error) { c.connStateCallbackLock.RLock() defer c.connStateCallbackLock.RUnlock() if c.connStateCallback == nil { return } c.connStateCallback.MarkManagementDisconnected(err) } func (c *GrpcClient) notifyConnected() { c.connStateCallbackLock.RLock() defer c.connStateCallbackLock.RUnlock() if c.connStateCallback == nil { return } c.connStateCallback.MarkManagementConnected() } func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { if info == nil { return nil } addresses := make([]*proto.NetworkAddress, 0, len(info.NetworkAddresses)) for _, addr := range info.NetworkAddresses { addresses = append(addresses, &proto.NetworkAddress{ NetIP: addr.NetIP.String(), Mac: addr.Mac, }) } return &proto.PeerSystemMeta{ Hostname: info.Hostname, GoOS: info.GoOS, OS: info.OS, Core: info.OSVersion, OSVersion: info.OSVersion, Platform: info.Platform, Kernel: info.Kernel, WiretrusteeVersion: info.WiretrusteeVersion, UiVersion: info.UIVersion, KernelVersion: info.KernelVersion, NetworkAddresses: addresses, SysSerialNumber: info.SystemSerialNumber, SysManufacturer: info.SystemManufacturer, SysProductName: info.SystemProductName, Environment: &proto.Environment{ Cloud: info.Environment.Cloud, Platform: info.Environment.Platform, }, } }