Support new Management service protocol (NetworkMap) (#193)

* feature: support new management service protocol

* chore: add more logging to track networkmap serial

* refactor: organize peer update code in engine

* chore: fix lint issues

* refactor: extract Signal client interface

* test: add signal client mock

* refactor: introduce Management Service client interface

* chore: place management and signal clients mocks to respective packages

* test: add Serial test to the engine

* fix: lint issues

* test: unit tests for a networkMapUpdate

* test: unit tests Sync update
This commit is contained in:
Mikhail Bragin 2022-01-18 16:44:58 +01:00 committed by GitHub
parent 9a3fba3fa3
commit 5db130a12e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1102 additions and 651 deletions

View File

@ -82,7 +82,7 @@ var (
)
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) {
func loginPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) {
loginResp, err := client.Login(serverPublicKey)
if err != nil {
@ -101,7 +101,7 @@ func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string)
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) {
func registerPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) {
var err error
if setupKey == "" {

View File

@ -83,7 +83,7 @@ func createEngineConfig(key wgtypes.Key, config *internal.Config, peerConfig *mg
}
// connectToSignal creates Signal Service client and established a connection
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.Client, error) {
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
var sigTLSEnabled bool
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
sigTLSEnabled = true
@ -101,7 +101,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
}
// connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.Client, *mgmProto.LoginResponse, error) {
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) {
log.Debugf("connecting to management server %s", managementAddr)
client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled)
if err != nil {

View File

@ -12,6 +12,7 @@ import (
mgmProto "github.com/wiretrustee/wiretrustee/management/proto"
signal "github.com/wiretrustee/wiretrustee/signal/client"
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
"github.com/wiretrustee/wiretrustee/util"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"math/rand"
"strings"
@ -44,9 +45,9 @@ type EngineConfig struct {
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
signal *signal.Client
signal signal.Client
// mgmClient is a Management Service client
mgmClient *mgm.Client
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
@ -64,6 +65,9 @@ type Engine struct {
ctx context.Context
wgInterface iface.WGIface
// networkSerial is the latest Serial (state ID) of the network sent by the Management service
networkSerial uint64
}
// Peer is an instance of the Connection Peer
@ -73,7 +77,7 @@ type Peer struct {
}
// NewEngine creates a new Connection Engine
func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
func NewEngine(signalClient signal.Client, mgmClient mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
return &Engine{
signal: signalClient,
mgmClient: mgmClient,
@ -84,6 +88,7 @@ func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *Engin
TURNs: []*ice.URL{},
cancel: cancel,
ctx: ctx,
networkSerial: 0,
}
}
@ -91,7 +96,7 @@ func (e *Engine) Stop() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
err := e.removeAllPeerConnections()
err := e.removeAllPeers()
if err != nil {
return err
}
@ -146,8 +151,22 @@ func (e *Engine) Start() error {
return nil
}
func (e *Engine) removePeers(peers []string) error {
for _, p := range peers {
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
currentPeers := make([]string, 0, len(e.peerConns))
for p := range e.peerConns {
currentPeers = append(currentPeers, p)
}
newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey())
}
toRemove := util.SliceDiff(currentPeers, newPeers)
for _, p := range toRemove {
err := e.removePeer(p)
if err != nil {
return err
@ -157,7 +176,7 @@ func (e *Engine) removePeers(peers []string) error {
return nil
}
func (e *Engine) removeAllPeerConnections() error {
func (e *Engine) removeAllPeers() error {
log.Debugf("removing all peer connections")
for p := range e.peerConns {
err := e.removePeer(p)
@ -189,6 +208,16 @@ func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
return -1
}
func (e *Engine) GetPeers() []string {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
peers := []string{}
for s := range e.peerConns {
peers = append(peers, s)
}
return peers
}
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
func (e *Engine) GetConnectedPeers() []string {
@ -205,7 +234,7 @@ func (e *Engine) GetConnectedPeers() []string {
return peers
}
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client) error {
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
err := s.Send(&sProto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
@ -223,7 +252,7 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
return nil
}
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client, isAnswer bool) error {
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
var t sProto.Body_Type
if isAnswer {
@ -246,11 +275,7 @@ func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.K
return nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@ -268,15 +293,24 @@ func (e *Engine) receiveManagementEvents() {
//todo update signal
}
if update.GetRemotePeers() != nil || update.GetRemotePeersIsEmpty() {
// empty arrays are serialized by protobuf to null, but for our case empty array is a valid state.
err := e.updatePeers(update.GetRemotePeers())
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
if err != nil {
return err
}
}
return nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
return e.handleSync(update)
})
if err != nil {
// happens if management is unavailable for a long time.
@ -327,27 +361,41 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
return nil
}
func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(remotePeers))
remotePeerMap := make(map[string]struct{})
for _, p := range remotePeers {
remotePeerMap[p.GetWgPubKey()] = struct{}{}
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
serial := networkMap.GetSerial()
if e.networkSerial > serial {
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
return nil
}
//remove peers that are no longer available for us
toRemove := []string{}
for p := range e.peerConns {
if _, ok := remotePeerMap[p]; !ok {
toRemove = append(toRemove, p)
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
if err != nil {
return err
}
}
err := e.removePeers(toRemove)
} else {
err := e.removePeers(networkMap.GetRemotePeers())
if err != nil {
return err
}
// add new peers
for _, p := range remotePeers {
err = e.addNewPeers(networkMap.GetRemotePeers())
if err != nil {
return err
}
}
e.networkSerial = serial
return nil
}
// addNewPeers finds and adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {
peerKey := p.GetWgPubKey()
peerIPs := p.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok {

View File

@ -37,6 +37,232 @@ var (
}
)
func TestEngine_UpdateNetworkMap(t *testing.T) {
// test setup
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun100",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, cancel, ctx)
type testCase struct {
idx int
networkMap *mgmtProto.NetworkMap
expectedLen int
expectedPeers []string
expectedSerial uint64
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st case - new peer and network map has Serial grater than local => apply the update
case1 := testCase{
idx: 1,
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{
peer1,
},
RemotePeersIsEmpty: false,
},
expectedLen: 1,
expectedPeers: []string{peer1.GetWgPubKey()},
expectedSerial: 1,
}
// 2nd case - one extra peer added and network map has Serial grater than local => apply the update
case2 := testCase{
idx: 2,
networkMap: &mgmtProto.NetworkMap{
Serial: 2,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{
peer1, peer2,
},
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
expectedSerial: 2,
}
// 3rd case - an update with 3 peers and Serial lower than the current serial of the engine => ignore the update
case3 := testCase{
idx: 3,
networkMap: &mgmtProto.NetworkMap{
Serial: 0,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{
peer1, peer2, peer3,
},
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
expectedSerial: 2,
}
// 4th case - an update with 2 peers (1 new and 1 old) => apply the update removing old peer and adding a new one
case4 := testCase{
idx: 3,
networkMap: &mgmtProto.NetworkMap{
Serial: 4,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{
peer2, peer3,
},
RemotePeersIsEmpty: false,
},
expectedLen: 2,
expectedPeers: []string{peer2.GetWgPubKey(), peer3.GetWgPubKey()},
expectedSerial: 4,
}
// 5th case - an update with all peers to be removed
case5 := testCase{
idx: 3,
networkMap: &mgmtProto.NetworkMap{
Serial: 5,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
},
expectedLen: 0,
expectedPeers: nil,
expectedSerial: 5,
}
for _, c := range []testCase{case1, case2, case3, case4, case5} {
err = engine.updateNetworkMap(c.networkMap)
if err != nil {
t.Fatal(err)
return
}
if len(engine.peerConns) != c.expectedLen {
t.Errorf("case %d expecting Engine.peerConns to be of size %d, got %d", c.idx, c.expectedLen, len(engine.peerConns))
}
if engine.networkSerial != c.expectedSerial {
t.Errorf("case %d expecting Engine.networkSerial to be equal to %d, actual %d", c.idx, c.expectedSerial, engine.networkSerial)
}
for _, p := range c.expectedPeers {
if _, ok := engine.peerConns[p]; !ok {
t.Errorf("case %d expecting Engine.peerConns to contain peer %s", c.idx, p)
}
}
}
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
WgIfaceName: "utun100",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, cancel, ctx)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start()
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
default:
}
if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_MultiplePeers(t *testing.T) {
//log.SetLevel(log.DebugLevel)
@ -58,23 +284,14 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sport := 10010
signalServer, err := startSignal(sport)
sigServer, err := startSignal(sport)
if err != nil {
t.Fatal(err)
return
}
defer signalServer.Stop()
defer sigServer.Stop()
mport := 33081
mgmtServer, err := startManagement(mport, &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Signal: &server.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dir,
HttpConfig: nil,
})
mgmtServer, err := startManagement(mport, dir)
if err != nil {
t.Fatal(err)
return
@ -201,7 +418,18 @@ func startSignal(port int) (*grpc.Server, error) {
return s, nil
}
func startManagement(port int, config *server.Config) (*grpc.Server, error) {
func startManagement(port int, dataDir string) (*grpc.Server, error) {
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Signal: &server.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {

View File

@ -419,3 +419,7 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
}
}()
}
func (conn *Conn) GetKey() string {
return conn.config.Key
}

View File

@ -1,253 +1,15 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/client/system"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"io"
"time"
)
type Client struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*Client, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 15 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
return nil, err
}
realClient := proto.NewManagementServiceClient(conn)
return &Client{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
}, nil
}
// Close closes connection to the Management Service
func (c *Client) Close() error {
return c.conn.Close()
}
//defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *Client) ready() bool {
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
var backOff = defaultBackoff(c.ctx)
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
if !c.ready() {
return fmt.Errorf("no connection to management")
}
// todo we already have it since we did the Login, maybe cache it locally?
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Errorf("failed getting Management Service public key: %s", err)
return err
}
stream, err := c.connectToStream(*serverPubKey)
if err != nil {
log.Errorf("failed to open Management Service stream: %s", err)
return err
}
log.Infof("connected to the Management Service stream")
// blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil {
backOff.Reset()
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
return err
}
return nil
}
func (c *Client) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, err
}
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
return c.realClient.Sync(c.ctx, syncReq)
}
func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {
log.Errorf("Management stream has been closed by server: %s", err)
return err
}
if err != nil {
log.Warnf("disconnected from Management Service sync stream: %v", err)
return err
}
log.Debugf("got an update message from Management Service")
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return err
}
err = msgHandler(decryptedResp)
if err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
return err
}
}
}
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
return nil, err
}
serverKey, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
}
func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
return nil, err
}
return loginResp, nil
}
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *Client) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
gi := system.GetInfo()
meta := &proto.PeerSystemMeta{
Hostname: gi.Hostname,
GoOS: gi.GoOS,
OS: gi.OS,
Core: gi.OSVersion,
Platform: gi.Platform,
Kernel: gi.Kernel,
WiretrusteeVersion: "",
}
log.Debugf("detected system %v", meta)
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *Client) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
return c.login(serverKey, &proto.LoginRequest{})
type Client interface {
io.Closer
Sync(msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key) (*proto.LoginResponse, error)
}

View File

@ -16,7 +16,7 @@ import (
"time"
)
var tested *Client
var tested *GrpcClient
var serverAddr string
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"

253
management/client/grpc.go Normal file
View File

@ -0,0 +1,253 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/client/system"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"io"
"time"
)
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 15 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
return nil, err
}
realClient := proto.NewManagementServiceClient(conn)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
}, nil
}
// Close closes connection to the Management Service
func (c *GrpcClient) Close() error {
return c.conn.Close()
}
//defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *GrpcClient) ready() bool {
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
var backOff = defaultBackoff(c.ctx)
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
if !c.ready() {
return fmt.Errorf("no connection to management")
}
// todo we already have it since we did the Login, maybe cache it locally?
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Errorf("failed getting Management Service public key: %s", err)
return err
}
stream, err := c.connectToStream(*serverPubKey)
if err != nil {
log.Errorf("failed to open Management Service stream: %s", err)
return err
}
log.Infof("connected to the Management Service stream")
// blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
if err != nil {
backOff.Reset()
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
return err
}
return nil
}
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, err
}
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
return c.realClient.Sync(c.ctx, syncReq)
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {
log.Errorf("Management stream has been closed by server: %s", err)
return err
}
if err != nil {
log.Warnf("disconnected from Management Service sync stream: %v", err)
return err
}
log.Debugf("got an update message from Management Service")
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return err
}
err = msgHandler(decryptedResp)
if err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
return err
}
}
}
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
return nil, err
}
serverKey, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
return nil, err
}
return loginResp, nil
}
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
gi := system.GetInfo()
meta := &proto.PeerSystemMeta{
Hostname: gi.Hostname,
GoOS: gi.GoOS,
OS: gi.OS,
Core: gi.OSVersion,
Platform: gi.Platform,
Kernel: gi.Kernel,
WiretrusteeVersion: "",
}
log.Debugf("detected system %v", meta)
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
return c.login(serverKey, &proto.LoginRequest{})
}

49
management/client/mock.go Normal file
View File

@ -0,0 +1,49 @@
package client
import (
"github.com/wiretrustee/wiretrustee/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MockClient struct {
CloseFunc func() error
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key) (*proto.LoginResponse, error)
}
func (m *MockClient) Close() error {
if m.CloseFunc == nil {
return nil
}
return m.CloseFunc()
}
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil {
return nil
}
return m.SyncFunc(msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
return m.RegisterFunc(serverKey, setupKey)
}
func (m *MockClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey)
}

View File

@ -1,26 +1,11 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/signal/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"io"
"strings"
"sync"
"time"
)
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
@ -31,317 +16,15 @@ type Status string
const StreamConnected Status = "Connected"
const StreamDisconnected Status = "Disconnected"
// Client Wraps the Signal Exchange Service gRpc client
type Client struct {
key wgtypes.Key
realClient proto.SignalExchangeClient
signalConn *grpc.ClientConn
ctx context.Context
stream proto.SignalExchange_ConnectStreamClient
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
connectedCh chan struct{}
mux sync.Mutex
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
status Status
}
func (c *Client) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *Client) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *Client) Close() error {
return c.signalConn.Close()
}
// NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*Client, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 15 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil {
log.Errorf("failed to connect to the signalling server %v", err)
return nil, err
}
return &Client{
realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx,
signalConn: conn,
key: key,
mux: sync.Mutex{},
status: StreamDisconnected,
}, nil
}
//defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
// The messages will be handled by msgHandler function provided.
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error {
var backOff = defaultBackoff(c.ctx)
operation := func() error {
c.notifyStreamDisconnected()
log.Debugf("signal connection state %v", c.signalConn.GetState())
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
// connect to Signal stream identifying ourselves with a public Wireguard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String())
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err
}
c.notifyStreamConnected()
log.Infof("connected to the Signal Service stream")
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
backOff.Reset()
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
return err
}
return nil
}
func (c *Client) notifyStreamDisconnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamDisconnected
}
func (c *Client) notifyStreamConnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamConnected
if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them
close(c.connectedCh)
c.connectedCh = nil
}
}
func (c *Client) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
}
func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
// add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key})
ctx := metadata.NewOutgoingContext(c.ctx, md)
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
c.stream = stream
if err != nil {
return nil, err
}
// blocks
header, err := c.stream.Header()
if err != nil {
return nil, err
}
registered := header.Get(proto.HeaderRegistered)
if len(registered) == 0 {
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
}
return stream, nil
}
// Ready indicates whether the client is okay and Ready to be used
// for now it just checks whether gRPC connection to the service is in state Ready
func (c *Client) Ready() bool {
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *Client) WaitStreamConnected() {
if c.status == StreamConnected {
return
}
ch := c.getStreamStatusChan()
select {
case <-c.ctx.Done():
case <-ch:
}
}
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// Client.connWg can be used to wait
func (c *Client) SendToStream(msg *proto.EncryptedMessage) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
if c.stream == nil {
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
}
err := c.stream.Send(msg)
if err != nil {
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
return err
}
return nil
}
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
if err != nil {
return nil, err
}
body := &proto.Body{}
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
if err != nil {
return nil, err
}
return &proto.Message{
Key: msg.Key,
RemoteKey: msg.RemoteKey,
Body: body,
}, nil
}
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
if err != nil {
return nil, err
}
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
if err != nil {
return nil, err
}
return &proto.EncryptedMessage{
Key: msg.GetKey(),
RemoteKey: msg.GetRemoteKey(),
Body: encryptedBody,
}, nil
}
// Send sends a message to the remote Peer through the Signal Exchange.
func (c *Client) Send(msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := c.encryptMessage(msg)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err = c.realClient.Send(ctx, encryptedMessage)
if err != nil {
return err
}
return nil
}
// receive receives messages from other peers coming through the Signal Exchange
func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient,
msgHandler func(msg *proto.Message) error) error {
for {
msg, err := stream.Recv()
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Warnf("stream canceled (usually indicates shutdown)")
return err
} else if s.Code() == codes.Unavailable {
log.Warnf("Signal Service is unavailable")
return err
} else if err == io.EOF {
log.Warnf("Signal Service stream closed by server")
return err
} else if err != nil {
return err
}
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg)
if err != nil {
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
}
err = msgHandler(decryptedMessage)
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something??
}
}
type Client interface {
io.Closer
StreamConnected() bool
GetStatus() Status
Receive(msgHandler func(msg *proto.Message) error) error
Ready() bool
WaitStreamConnected()
SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error
}
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
@ -369,7 +52,7 @@ func MarshalCredential(myKey wgtypes.Key, remoteKey wgtypes.Key, credential *Cre
}, nil
}
// Credential is an instance of a Client's Credential
// Credential is an instance of a GrpcClient's Credential
type Credential struct {
UFrag string
Pwd string

View File

@ -17,7 +17,7 @@ import (
"time"
)
var _ = Describe("Client", func() {
var _ = Describe("GrpcClient", func() {
var (
addr string
@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
})
func createSignalClient(addr string, key wgtypes.Key) *Client {
func createSignalClient(addr string, key wgtypes.Key) *GrpcClient {
var sigTLSEnabled = false
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
if err != nil {

336
signal/client/grpc.go Normal file
View File

@ -0,0 +1,336 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/encryption"
"github.com/wiretrustee/wiretrustee/signal/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"io"
"sync"
"time"
)
// GrpcClient Wraps the Signal Exchange Service gRpc client
type GrpcClient struct {
key wgtypes.Key
realClient proto.SignalExchangeClient
signalConn *grpc.ClientConn
ctx context.Context
stream proto.SignalExchange_ConnectStreamClient
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
connectedCh chan struct{}
mux sync.Mutex
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
status Status
}
func (c *GrpcClient) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
return c.signalConn.Close()
}
// NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 15 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil {
log.Errorf("failed to connect to the signalling server %v", err)
return nil, err
}
return &GrpcClient{
realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx,
signalConn: conn,
key: key,
mux: sync.Mutex{},
status: StreamDisconnected,
}, nil
}
//defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
// The messages will be handled by msgHandler function provided.
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
var backOff = defaultBackoff(c.ctx)
operation := func() error {
c.notifyStreamDisconnected()
log.Debugf("signal connection state %v", c.signalConn.GetState())
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
// connect to Signal stream identifying ourselves with a public Wireguard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String())
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err
}
c.notifyStreamConnected()
log.Infof("connected to the Signal Service stream")
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
backOff.Reset()
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
return err
}
return nil
}
func (c *GrpcClient) notifyStreamDisconnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamDisconnected
}
func (c *GrpcClient) notifyStreamConnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamConnected
if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them
close(c.connectedCh)
c.connectedCh = nil
}
}
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
}
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
// add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key})
ctx := metadata.NewOutgoingContext(c.ctx, md)
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
c.stream = stream
if err != nil {
return nil, err
}
// blocks
header, err := c.stream.Header()
if err != nil {
return nil, err
}
registered := header.Get(proto.HeaderRegistered)
if len(registered) == 0 {
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
}
return stream, nil
}
// Ready indicates whether the client is okay and Ready to be used
// for now it just checks whether gRPC connection to the service is in state Ready
func (c *GrpcClient) Ready() bool {
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *GrpcClient) WaitStreamConnected() {
if c.status == StreamConnected {
return
}
ch := c.getStreamStatusChan()
select {
case <-c.ctx.Done():
case <-ch:
}
}
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// GrpcClient.connWg can be used to wait
func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
if c.stream == nil {
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call GrpcClient.Receive before sending messages")
}
err := c.stream.Send(msg)
if err != nil {
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
return err
}
return nil
}
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
if err != nil {
return nil, err
}
body := &proto.Body{}
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
if err != nil {
return nil, err
}
return &proto.Message{
Key: msg.Key,
RemoteKey: msg.RemoteKey,
Body: body,
}, nil
}
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
if err != nil {
return nil, err
}
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
if err != nil {
return nil, err
}
return &proto.EncryptedMessage{
Key: msg.GetKey(),
RemoteKey: msg.GetRemoteKey(),
Body: encryptedBody,
}, nil
}
// Send sends a message to the remote Peer through the Signal Exchange.
func (c *GrpcClient) Send(msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := c.encryptMessage(msg)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err = c.realClient.Send(ctx, encryptedMessage)
if err != nil {
return err
}
return nil
}
// receive receives messages from other peers coming through the Signal Exchange
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
msgHandler func(msg *proto.Message) error) error {
for {
msg, err := stream.Recv()
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Warnf("stream canceled (usually indicates shutdown)")
return err
} else if s.Code() == codes.Unavailable {
log.Warnf("Signal Service is unavailable")
return err
} else if err == io.EOF {
log.Warnf("Signal Service stream closed by server")
return err
} else if err != nil {
return err
}
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg)
if err != nil {
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
}
err = msgHandler(decryptedMessage)
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something??
}
}
}

72
signal/client/mock.go Normal file
View File

@ -0,0 +1,72 @@
package client
import (
"github.com/wiretrustee/wiretrustee/signal/proto"
)
type MockClient struct {
CloseFunc func() error
GetStatusFunc func() Status
StreamConnectedFunc func() bool
ReadyFunc func() bool
WaitStreamConnectedFunc func()
ReceiveFunc func(msgHandler func(msg *proto.Message) error) error
SendToStreamFunc func(msg *proto.EncryptedMessage) error
SendFunc func(msg *proto.Message) error
}
func (sm *MockClient) Close() error {
if sm.CloseFunc == nil {
return nil
}
return sm.CloseFunc()
}
func (sm *MockClient) GetStatus() Status {
if sm.GetStatusFunc == nil {
return ""
}
return sm.GetStatusFunc()
}
func (sm *MockClient) StreamConnected() bool {
if sm.StreamConnectedFunc == nil {
return false
}
return sm.StreamConnectedFunc()
}
func (sm *MockClient) Ready() bool {
if sm.ReadyFunc == nil {
return false
}
return sm.ReadyFunc()
}
func (sm *MockClient) WaitStreamConnected() {
if sm.WaitStreamConnectedFunc == nil {
return
}
sm.WaitStreamConnectedFunc()
}
func (sm *MockClient) Receive(msgHandler func(msg *proto.Message) error) error {
if sm.ReceiveFunc == nil {
return nil
}
return sm.ReceiveFunc(msgHandler)
}
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
if sm.SendToStreamFunc == nil {
return nil
}
return sm.SendToStreamFunc(msg)
}
func (sm *MockClient) Send(msg *proto.Message) error {
if sm.SendFunc == nil {
return nil
}
return sm.SendFunc(msg)
}

16
util/common.go Normal file
View File

@ -0,0 +1,16 @@
package util
// SliceDiff returns the elements in slice `x` that are not in slice `y`
func SliceDiff(x, y []string) []string {
mapY := make(map[string]struct{}, len(y))
for _, val := range y {
mapY[val] = struct{}{}
}
var diff []string
for _, val := range x {
if _, found := mapY[val]; !found {
diff = append(diff, val)
}
}
return diff
}