netbird/management/client/grpc.go
Viktor Liu 2475473227
Support client default routes for Linux (#1667)
All routes are now installed in a custom netbird routing table.
Management and wireguard traffic is now marked with a custom fwmark.
When the mark is present the traffic is routed via the main routing table, bypassing the VPN.
When the mark is absent the traffic is routed via the netbird routing table, if:
- there's no match in the main routing table
- it would match the default route in the routing table

IPv6 traffic is blocked when a default route IPv4 route is configured to avoid leakage.
2024-03-21 16:49:28 +01:00

487 lines
15 KiB
Go

package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
"time"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
const ConnectTimeout = 10 * time.Second
// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected(error)
MarkManagementConnected()
}
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
return nil, err
}
realClient := proto.NewManagementServiceClient(conn)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil
}
// Close closes connection to the Management Service
func (c *GrpcClient) Close() error {
return c.conn.Close()
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *GrpcClient) ready() bool {
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
backOff := defaultBackoff(c.ctx)
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
c.conn.WaitForStateChange(c.ctx, connState)
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
return err
}
log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil {
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
backOff.Reset() // reset backoff counter after successful connection
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
return err
}
return nil
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, err
}
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
}
return sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return err
}
log.Debugf("got an update message from Management Service")
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return err
}
err = msgHandler(decryptedResp)
if err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
return err
}
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
return nil, err
}
serverKey, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
func (c *GrpcClient) IsHealthy() bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
return false
case connectivity.Connecting:
return true
case connectivity.Shutdown:
return true
case connectivity.Idle:
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
if err != nil {
c.notifyDisconnected(err)
log.Warnf("health check returned: %s", err)
return false
}
c.notifyConnected()
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
return nil, err
}
return loginResp, nil
}
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys})
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return nil, err
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return nil, err
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementDisconnected(err)
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementConnected()
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil
}
addresses := make([]*proto.NetworkAddress, 0, len(info.NetworkAddresses))
for _, addr := range info.NetworkAddresses {
addresses = append(addresses, &proto.NetworkAddress{
NetIP: addr.NetIP.String(),
Mac: addr.Mac,
})
}
return &proto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
OS: info.OS,
Core: info.OSVersion,
OSVersion: info.OSVersion,
Platform: info.Platform,
Kernel: info.Kernel,
WiretrusteeVersion: info.WiretrusteeVersion,
UiVersion: info.UIVersion,
KernelVersion: info.KernelVersion,
NetworkAddresses: addresses,
SysSerialNumber: info.SystemSerialNumber,
SysManufacturer: info.SystemManufacturer,
SysProductName: info.SystemProductName,
Environment: &proto.Environment{
Cloud: info.Environment.Cloud,
Platform: info.Environment.Platform,
},
}
}