mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-24 15:48:52 +01:00
Support new Management service protocol (NetworkMap) (#193)
* feature: support new management service protocol * chore: add more logging to track networkmap serial * refactor: organize peer update code in engine * chore: fix lint issues * refactor: extract Signal client interface * test: add signal client mock * refactor: introduce Management Service client interface * chore: place management and signal clients mocks to respective packages * test: add Serial test to the engine * fix: lint issues * test: unit tests for a networkMapUpdate * test: unit tests Sync update
This commit is contained in:
parent
9a3fba3fa3
commit
5db130a12e
@ -82,7 +82,7 @@ var (
|
||||
)
|
||||
|
||||
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
||||
func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) {
|
||||
func loginPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
loginResp, err := client.Login(serverPublicKey)
|
||||
if err != nil {
|
||||
@ -101,7 +101,7 @@ func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string)
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func registerPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) {
|
||||
func registerPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
var err error
|
||||
if setupKey == "" {
|
||||
|
@ -83,7 +83,7 @@ func createEngineConfig(key wgtypes.Key, config *internal.Config, peerConfig *mg
|
||||
}
|
||||
|
||||
// connectToSignal creates Signal Service client and established a connection
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.Client, error) {
|
||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||
var sigTLSEnabled bool
|
||||
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
||||
sigTLSEnabled = true
|
||||
@ -101,7 +101,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
||||
}
|
||||
|
||||
// connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.Client, *mgmProto.LoginResponse, error) {
|
||||
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) {
|
||||
log.Debugf("connecting to management server %s", managementAddr)
|
||||
client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled)
|
||||
if err != nil {
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
mgmProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||||
signal "github.com/wiretrustee/wiretrustee/signal/client"
|
||||
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"github.com/wiretrustee/wiretrustee/util"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"math/rand"
|
||||
"strings"
|
||||
@ -44,9 +45,9 @@ type EngineConfig struct {
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
type Engine struct {
|
||||
// signal is a Signal Service client
|
||||
signal *signal.Client
|
||||
signal signal.Client
|
||||
// mgmClient is a Management Service client
|
||||
mgmClient *mgm.Client
|
||||
mgmClient mgm.Client
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
|
||||
@ -64,6 +65,9 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
|
||||
wgInterface iface.WGIface
|
||||
|
||||
// networkSerial is the latest Serial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@ -73,7 +77,7 @@ type Peer struct {
|
||||
}
|
||||
|
||||
// NewEngine creates a new Connection Engine
|
||||
func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
|
||||
func NewEngine(signalClient signal.Client, mgmClient mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
|
||||
return &Engine{
|
||||
signal: signalClient,
|
||||
mgmClient: mgmClient,
|
||||
@ -84,6 +88,7 @@ func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *Engin
|
||||
TURNs: []*ice.URL{},
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
networkSerial: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,7 +96,7 @@ func (e *Engine) Stop() error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
err := e.removeAllPeerConnections()
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -146,8 +151,22 @@ func (e *Engine) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) removePeers(peers []string) error {
|
||||
for _, p := range peers {
|
||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service
|
||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
|
||||
currentPeers := make([]string, 0, len(e.peerConns))
|
||||
for p := range e.peerConns {
|
||||
currentPeers = append(currentPeers, p)
|
||||
}
|
||||
|
||||
newPeers := make([]string, 0, len(peersUpdate))
|
||||
for _, p := range peersUpdate {
|
||||
newPeers = append(newPeers, p.GetWgPubKey())
|
||||
}
|
||||
|
||||
toRemove := util.SliceDiff(currentPeers, newPeers)
|
||||
|
||||
for _, p := range toRemove {
|
||||
err := e.removePeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -157,7 +176,7 @@ func (e *Engine) removePeers(peers []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) removeAllPeerConnections() error {
|
||||
func (e *Engine) removeAllPeers() error {
|
||||
log.Debugf("removing all peer connections")
|
||||
for p := range e.peerConns {
|
||||
err := e.removePeer(p)
|
||||
@ -189,6 +208,16 @@ func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
|
||||
|
||||
return -1
|
||||
}
|
||||
func (e *Engine) GetPeers() []string {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
peers := []string{}
|
||||
for s := range e.peerConns {
|
||||
peers = append(peers, s)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||
func (e *Engine) GetConnectedPeers() []string {
|
||||
@ -205,7 +234,7 @@ func (e *Engine) GetConnectedPeers() []string {
|
||||
return peers
|
||||
}
|
||||
|
||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client) error {
|
||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
||||
err := s.Send(&sProto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey.String(),
|
||||
@ -223,7 +252,7 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
|
||||
return nil
|
||||
}
|
||||
|
||||
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client, isAnswer bool) error {
|
||||
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
||||
|
||||
var t sProto.Body_Type
|
||||
if isAnswer {
|
||||
@ -246,11 +275,7 @@ func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.K
|
||||
return nil
|
||||
}
|
||||
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@ -268,15 +293,24 @@ func (e *Engine) receiveManagementEvents() {
|
||||
//todo update signal
|
||||
}
|
||||
|
||||
if update.GetRemotePeers() != nil || update.GetRemotePeersIsEmpty() {
|
||||
// empty arrays are serialized by protobuf to null, but for our case empty array is a valid state.
|
||||
err := e.updatePeers(update.GetRemotePeers())
|
||||
if update.GetNetworkMap() != nil {
|
||||
// only apply new changes and ignore old ones
|
||||
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
|
||||
return e.handleSync(update)
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
@ -327,27 +361,41 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(remotePeers))
|
||||
remotePeerMap := make(map[string]struct{})
|
||||
for _, p := range remotePeers {
|
||||
remotePeerMap[p.GetWgPubKey()] = struct{}{}
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
serial := networkMap.GetSerial()
|
||||
if e.networkSerial > serial {
|
||||
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
||||
return nil
|
||||
}
|
||||
|
||||
//remove peers that are no longer available for us
|
||||
toRemove := []string{}
|
||||
for p := range e.peerConns {
|
||||
if _, ok := remotePeerMap[p]; !ok {
|
||||
toRemove = append(toRemove, p)
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
// cleanup request, most likely our peer has been deleted
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := e.removePeers(toRemove)
|
||||
} else {
|
||||
err := e.removePeers(networkMap.GetRemotePeers())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add new peers
|
||||
for _, p := range remotePeers {
|
||||
err = e.addNewPeers(networkMap.GetRemotePeers())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNewPeers finds and adds peers that were not know before but arrived from the Management service with the update
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
peerKey := p.GetWgPubKey()
|
||||
peerIPs := p.GetAllowedIps()
|
||||
if _, ok := e.peerConns[peerKey]; !ok {
|
||||
|
@ -37,6 +37,232 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
|
||||
// test setup
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||
WgIfaceName: "utun100",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, cancel, ctx)
|
||||
|
||||
type testCase struct {
|
||||
idx int
|
||||
networkMap *mgmtProto.NetworkMap
|
||||
expectedLen int
|
||||
expectedPeers []string
|
||||
expectedSerial uint64
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
|
||||
// 1st case - new peer and network map has Serial grater than local => apply the update
|
||||
case1 := testCase{
|
||||
idx: 1,
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 1,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||
peer1,
|
||||
},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
expectedLen: 1,
|
||||
expectedPeers: []string{peer1.GetWgPubKey()},
|
||||
expectedSerial: 1,
|
||||
}
|
||||
|
||||
// 2nd case - one extra peer added and network map has Serial grater than local => apply the update
|
||||
case2 := testCase{
|
||||
idx: 2,
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 2,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||
peer1, peer2,
|
||||
},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
expectedLen: 2,
|
||||
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
|
||||
expectedSerial: 2,
|
||||
}
|
||||
|
||||
// 3rd case - an update with 3 peers and Serial lower than the current serial of the engine => ignore the update
|
||||
case3 := testCase{
|
||||
idx: 3,
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 0,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||
peer1, peer2, peer3,
|
||||
},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
expectedLen: 2,
|
||||
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
|
||||
expectedSerial: 2,
|
||||
}
|
||||
|
||||
// 4th case - an update with 2 peers (1 new and 1 old) => apply the update removing old peer and adding a new one
|
||||
case4 := testCase{
|
||||
idx: 3,
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 4,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||
peer2, peer3,
|
||||
},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
expectedLen: 2,
|
||||
expectedPeers: []string{peer2.GetWgPubKey(), peer3.GetWgPubKey()},
|
||||
expectedSerial: 4,
|
||||
}
|
||||
|
||||
// 5th case - an update with all peers to be removed
|
||||
case5 := testCase{
|
||||
idx: 3,
|
||||
networkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 5,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
},
|
||||
expectedLen: 0,
|
||||
expectedPeers: nil,
|
||||
expectedSerial: 5,
|
||||
}
|
||||
|
||||
for _, c := range []testCase{case1, case2, case3, case4, case5} {
|
||||
err = engine.updateNetworkMap(c.networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(engine.peerConns) != c.expectedLen {
|
||||
t.Errorf("case %d expecting Engine.peerConns to be of size %d, got %d", c.idx, c.expectedLen, len(engine.peerConns))
|
||||
}
|
||||
|
||||
if engine.networkSerial != c.expectedSerial {
|
||||
t.Errorf("case %d expecting Engine.networkSerial to be equal to %d, actual %d", c.idx, c.expectedSerial, engine.networkSerial)
|
||||
}
|
||||
|
||||
for _, p := range c.expectedPeers {
|
||||
if _, ok := engine.peerConns[p]; !ok {
|
||||
t.Errorf("case %d expecting Engine.peerConns to contain peer %s", c.idx, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestEngine_Sync(t *testing.T) {
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
|
||||
WgIfaceName: "utun100",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, cancel, ctx)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peer1 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
AllowedIps: []string{"100.64.0.10/24"},
|
||||
}
|
||||
peer2 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.11/24"},
|
||||
}
|
||||
peer3 := &mgmtProto.RemotePeerConfig{
|
||||
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||
AllowedIps: []string{"100.64.0.12/24"},
|
||||
}
|
||||
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||
updates <- &mgmtProto.SyncResponse{
|
||||
NetworkMap: &mgmtProto.NetworkMap{
|
||||
Serial: 10,
|
||||
PeerConfig: nil,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
RemotePeersIsEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second * 2)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout while waiting for test to finish")
|
||||
default:
|
||||
}
|
||||
|
||||
if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestEngine_MultiplePeers(t *testing.T) {
|
||||
|
||||
//log.SetLevel(log.DebugLevel)
|
||||
@ -58,23 +284,14 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sport := 10010
|
||||
signalServer, err := startSignal(sport)
|
||||
sigServer, err := startSignal(sport)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer signalServer.Stop()
|
||||
defer sigServer.Stop()
|
||||
mport := 33081
|
||||
mgmtServer, err := startManagement(mport, &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Signal: &server.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dir,
|
||||
HttpConfig: nil,
|
||||
})
|
||||
mgmtServer, err := startManagement(mport, dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -201,7 +418,18 @@ func startSignal(port int) (*grpc.Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func startManagement(port int, config *server.Config) (*grpc.Server, error) {
|
||||
func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Signal: &server.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
Datadir: dataDir,
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
|
@ -419,3 +419,7 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
@ -1,253 +1,15 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/client/system"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
key wgtypes.Key
|
||||
realClient proto.ManagementServiceClient
|
||||
ctx context.Context
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
// NewClient creates a new client to Management service
|
||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*Client, error) {
|
||||
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(
|
||||
mgmCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 15 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed creating connection to Management Service %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
realClient := proto.NewManagementServiceClient(conn)
|
||||
|
||||
return &Client{
|
||||
key: ourPrivateKey,
|
||||
realClient: realClient,
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes connection to the Management Service
|
||||
func (c *Client) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
//defaultBackoff is a basic backoff mechanism for general issues
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// ready indicates whether the client is okay and ready to be used
|
||||
// for now it just checks whether gRPC connection to the service is ready
|
||||
func (c *Client) ready() bool {
|
||||
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
|
||||
}
|
||||
|
||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||
// Blocking request. The result will be sent via msgHandler callback function
|
||||
func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
|
||||
var backOff = defaultBackoff(c.ctx)
|
||||
|
||||
operation := func() error {
|
||||
|
||||
log.Debugf("management connection state %v", c.conn.GetState())
|
||||
|
||||
if !c.ready() {
|
||||
return fmt.Errorf("no connection to management")
|
||||
}
|
||||
|
||||
// todo we already have it since we did the Login, maybe cache it locally?
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed getting Management Service public key: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := c.connectToStream(*serverPubKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open Management Service stream: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("connected to the Management Service stream")
|
||||
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
backOff.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||
req := &proto.SyncRequest{}
|
||||
|
||||
myPrivateKey := c.key
|
||||
myPublicKey := myPrivateKey.PublicKey()
|
||||
|
||||
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed encrypting message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
||||
return c.realClient.Sync(c.ctx, syncReq)
|
||||
}
|
||||
|
||||
func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
for {
|
||||
update, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
log.Errorf("Management stream has been closed by server: %s", err)
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from Management Service sync stream: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("got an update message from Management Service")
|
||||
decryptedResp := &proto.SyncResponse{}
|
||||
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = msgHandler(decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
|
||||
func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management")
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
||||
defer cancel()
|
||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &serverKey, nil
|
||||
}
|
||||
|
||||
func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management")
|
||||
}
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
||||
defer cancel()
|
||||
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: loginReq,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt registration message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||
// Takes care of encrypting and decrypting messages.
|
||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||
func (c *Client) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
|
||||
gi := system.GetInfo()
|
||||
meta := &proto.PeerSystemMeta{
|
||||
Hostname: gi.Hostname,
|
||||
GoOS: gi.GoOS,
|
||||
OS: gi.OS,
|
||||
Core: gi.OSVersion,
|
||||
Platform: gi.Platform,
|
||||
Kernel: gi.Kernel,
|
||||
WiretrusteeVersion: "",
|
||||
}
|
||||
log.Debugf("detected system %v", meta)
|
||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta})
|
||||
}
|
||||
|
||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||
func (c *Client) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
|
||||
return c.login(serverKey, &proto.LoginRequest{})
|
||||
type Client interface {
|
||||
io.Closer
|
||||
Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error)
|
||||
Login(serverKey wgtypes.Key) (*proto.LoginResponse, error)
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var tested *Client
|
||||
var tested *GrpcClient
|
||||
var serverAddr string
|
||||
|
||||
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
253
management/client/grpc.go
Normal file
253
management/client/grpc.go
Normal file
@ -0,0 +1,253 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/client/system"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GrpcClient struct {
|
||||
key wgtypes.Key
|
||||
realClient proto.ManagementServiceClient
|
||||
ctx context.Context
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
// NewClient creates a new client to Management service
|
||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(
|
||||
mgmCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 15 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed creating connection to Management Service %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
realClient := proto.NewManagementServiceClient(conn)
|
||||
|
||||
return &GrpcClient{
|
||||
key: ourPrivateKey,
|
||||
realClient: realClient,
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes connection to the Management Service
|
||||
func (c *GrpcClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
//defaultBackoff is a basic backoff mechanism for general issues
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// ready indicates whether the client is okay and ready to be used
|
||||
// for now it just checks whether gRPC connection to the service is ready
|
||||
func (c *GrpcClient) ready() bool {
|
||||
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
|
||||
}
|
||||
|
||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||
// Blocking request. The result will be sent via msgHandler callback function
|
||||
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
|
||||
var backOff = defaultBackoff(c.ctx)
|
||||
|
||||
operation := func() error {
|
||||
|
||||
log.Debugf("management connection state %v", c.conn.GetState())
|
||||
|
||||
if !c.ready() {
|
||||
return fmt.Errorf("no connection to management")
|
||||
}
|
||||
|
||||
// todo we already have it since we did the Login, maybe cache it locally?
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed getting Management Service public key: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := c.connectToStream(*serverPubKey)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open Management Service stream: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("connected to the Management Service stream")
|
||||
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
backOff.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||
req := &proto.SyncRequest{}
|
||||
|
||||
myPrivateKey := c.key
|
||||
myPublicKey := myPrivateKey.PublicKey()
|
||||
|
||||
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed encrypting message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
||||
return c.realClient.Sync(c.ctx, syncReq)
|
||||
}
|
||||
|
||||
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
for {
|
||||
update, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
log.Errorf("Management stream has been closed by server: %s", err)
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from Management Service sync stream: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("got an update message from Management Service")
|
||||
decryptedResp := &proto.SyncResponse{}
|
||||
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = msgHandler(decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
|
||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management")
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
||||
defer cancel()
|
||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &serverKey, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management")
|
||||
}
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
||||
defer cancel()
|
||||
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: loginReq,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt registration message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||
// Takes care of encrypting and decrypting messages.
|
||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
|
||||
gi := system.GetInfo()
|
||||
meta := &proto.PeerSystemMeta{
|
||||
Hostname: gi.Hostname,
|
||||
GoOS: gi.GoOS,
|
||||
OS: gi.OS,
|
||||
Core: gi.OSVersion,
|
||||
Platform: gi.Platform,
|
||||
Kernel: gi.Kernel,
|
||||
WiretrusteeVersion: "",
|
||||
}
|
||||
log.Debugf("detected system %v", meta)
|
||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta})
|
||||
}
|
||||
|
||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
|
||||
return c.login(serverKey, &proto.LoginRequest{})
|
||||
}
|
49
management/client/mock.go
Normal file
49
management/client/mock.go
Normal file
@ -0,0 +1,49 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error)
|
||||
LoginFunc func(serverKey wgtypes.Key) (*proto.LoginResponse, error)
|
||||
}
|
||||
|
||||
func (m *MockClient) Close() error {
|
||||
if m.CloseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
if m.SyncFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.SyncFunc(msgHandler)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if m.GetServerPublicKeyFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetServerPublicKeyFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
|
||||
if m.RegisterFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.RegisterFunc(serverKey, setupKey)
|
||||
}
|
||||
|
||||
func (m *MockClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
|
||||
if m.LoginFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.LoginFunc(serverKey)
|
||||
}
|
@ -1,26 +1,11 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
|
||||
@ -31,317 +16,15 @@ type Status string
|
||||
const StreamConnected Status = "Connected"
|
||||
const StreamDisconnected Status = "Disconnected"
|
||||
|
||||
// Client Wraps the Signal Exchange Service gRpc client
|
||||
type Client struct {
|
||||
key wgtypes.Key
|
||||
realClient proto.SignalExchangeClient
|
||||
signalConn *grpc.ClientConn
|
||||
ctx context.Context
|
||||
stream proto.SignalExchange_ConnectStreamClient
|
||||
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
|
||||
connectedCh chan struct{}
|
||||
mux sync.Mutex
|
||||
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
|
||||
status Status
|
||||
}
|
||||
|
||||
func (c *Client) StreamConnected() bool {
|
||||
return c.status == StreamConnected
|
||||
}
|
||||
|
||||
func (c *Client) GetStatus() Status {
|
||||
return c.status
|
||||
}
|
||||
|
||||
// Close Closes underlying connections to the Signal Exchange
|
||||
func (c *Client) Close() error {
|
||||
return c.signalConn.Close()
|
||||
}
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*Client, error) {
|
||||
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
}
|
||||
|
||||
sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(
|
||||
sigCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 15 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to the signalling server %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
realClient: proto.NewSignalExchangeClient(conn),
|
||||
ctx: ctx,
|
||||
signalConn: conn,
|
||||
key: key,
|
||||
mux: sync.Mutex{},
|
||||
status: StreamDisconnected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//defaultBackoff is a basic backoff mechanism for general issues
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
}
|
||||
|
||||
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
|
||||
// The messages will be handled by msgHandler function provided.
|
||||
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
||||
func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
var backOff = defaultBackoff(c.ctx)
|
||||
|
||||
operation := func() error {
|
||||
|
||||
c.notifyStreamDisconnected()
|
||||
|
||||
log.Debugf("signal connection state %v", c.signalConn.GetState())
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
|
||||
// connect to Signal stream identifying ourselves with a public Wireguard key
|
||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||
stream, err := c.connect(c.key.PublicKey().String())
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.notifyStreamConnected()
|
||||
|
||||
log.Infof("connected to the Signal Service stream")
|
||||
|
||||
// start receiving messages from the Signal stream (from other peers through signal)
|
||||
err = c.receive(stream, msgHandler)
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||
backOff.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *Client) notifyStreamDisconnected() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.status = StreamDisconnected
|
||||
}
|
||||
|
||||
func (c *Client) notifyStreamConnected() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.status = StreamConnected
|
||||
if c.connectedCh != nil {
|
||||
// there are goroutines waiting on this channel -> release them
|
||||
close(c.connectedCh)
|
||||
c.connectedCh = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getStreamStatusChan() <-chan struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
if c.connectedCh == nil {
|
||||
c.connectedCh = make(chan struct{})
|
||||
}
|
||||
return c.connectedCh
|
||||
}
|
||||
|
||||
func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||
c.stream = nil
|
||||
|
||||
// add key fingerprint to the request header to be identified on the server side
|
||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||
ctx := metadata.NewOutgoingContext(c.ctx, md)
|
||||
|
||||
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
|
||||
|
||||
c.stream = stream
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// blocks
|
||||
header, err := c.stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registered := header.Get(proto.HeaderRegistered)
|
||||
if len(registered) == 0 {
|
||||
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Ready indicates whether the client is okay and Ready to be used
|
||||
// for now it just checks whether gRPC connection to the service is in state Ready
|
||||
func (c *Client) Ready() bool {
|
||||
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
|
||||
}
|
||||
|
||||
// WaitStreamConnected waits until the client is connected to the Signal stream
|
||||
func (c *Client) WaitStreamConnected() {
|
||||
|
||||
if c.status == StreamConnected {
|
||||
return
|
||||
}
|
||||
|
||||
ch := c.getStreamStatusChan()
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
case <-ch:
|
||||
}
|
||||
}
|
||||
|
||||
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
|
||||
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
|
||||
// Client.connWg can be used to wait
|
||||
func (c *Client) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
if c.stream == nil {
|
||||
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
|
||||
}
|
||||
|
||||
err := c.stream.Send(msg)
|
||||
if err != nil {
|
||||
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
||||
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := &proto.Body{}
|
||||
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.Message{
|
||||
Key: msg.Key,
|
||||
RemoteKey: msg.RemoteKey,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
||||
|
||||
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
Key: msg.GetKey(),
|
||||
RemoteKey: msg.GetRemoteKey(),
|
||||
Body: encryptedBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send sends a message to the remote Peer through the Signal Exchange.
|
||||
func (c *Client) Send(msg *proto.Message) error {
|
||||
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
|
||||
encryptedMessage, err := c.encryptMessage(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, err = c.realClient.Send(ctx, encryptedMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// receive receives messages from other peers coming through the Signal Exchange
|
||||
func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
log.Warnf("stream canceled (usually indicates shutdown)")
|
||||
return err
|
||||
} else if s.Code() == codes.Unavailable {
|
||||
log.Warnf("Signal Service is unavailable")
|
||||
return err
|
||||
} else if err == io.EOF {
|
||||
log.Warnf("Signal Service stream closed by server")
|
||||
return err
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||
|
||||
decryptedMessage, err := c.decryptMessage(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
}
|
||||
|
||||
err = msgHandler(decryptedMessage)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
//todo send something??
|
||||
}
|
||||
}
|
||||
type Client interface {
|
||||
io.Closer
|
||||
StreamConnected() bool
|
||||
GetStatus() Status
|
||||
Receive(msgHandler func(msg *proto.Message) error) error
|
||||
Ready() bool
|
||||
WaitStreamConnected()
|
||||
SendToStream(msg *proto.EncryptedMessage) error
|
||||
Send(msg *proto.Message) error
|
||||
}
|
||||
|
||||
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
|
||||
@ -369,7 +52,7 @@ func MarshalCredential(myKey wgtypes.Key, remoteKey wgtypes.Key, credential *Cre
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Credential is an instance of a Client's Credential
|
||||
// Credential is an instance of a GrpcClient's Credential
|
||||
type Credential struct {
|
||||
UFrag string
|
||||
Pwd string
|
||||
|
@ -17,7 +17,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ = Describe("Client", func() {
|
||||
var _ = Describe("GrpcClient", func() {
|
||||
|
||||
var (
|
||||
addr string
|
||||
@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
|
||||
|
||||
})
|
||||
|
||||
func createSignalClient(addr string, key wgtypes.Key) *Client {
|
||||
func createSignalClient(addr string, key wgtypes.Key) *GrpcClient {
|
||||
var sigTLSEnabled = false
|
||||
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
|
||||
if err != nil {
|
||||
|
336
signal/client/grpc.go
Normal file
336
signal/client/grpc.go
Normal file
@ -0,0 +1,336 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wiretrustee/wiretrustee/encryption"
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GrpcClient Wraps the Signal Exchange Service gRpc client
|
||||
type GrpcClient struct {
|
||||
key wgtypes.Key
|
||||
realClient proto.SignalExchangeClient
|
||||
signalConn *grpc.ClientConn
|
||||
ctx context.Context
|
||||
stream proto.SignalExchange_ConnectStreamClient
|
||||
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
|
||||
connectedCh chan struct{}
|
||||
mux sync.Mutex
|
||||
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
|
||||
status Status
|
||||
}
|
||||
|
||||
func (c *GrpcClient) StreamConnected() bool {
|
||||
return c.status == StreamConnected
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetStatus() Status {
|
||||
return c.status
|
||||
}
|
||||
|
||||
// Close Closes underlying connections to the Signal Exchange
|
||||
func (c *GrpcClient) Close() error {
|
||||
return c.signalConn.Close()
|
||||
}
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
}
|
||||
|
||||
sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(
|
||||
sigCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 15 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to the signalling server %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GrpcClient{
|
||||
realClient: proto.NewSignalExchangeClient(conn),
|
||||
ctx: ctx,
|
||||
signalConn: conn,
|
||||
key: key,
|
||||
mux: sync.Mutex{},
|
||||
status: StreamDisconnected,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//defaultBackoff is a basic backoff mechanism for general issues
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
}
|
||||
|
||||
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
|
||||
// The messages will be handled by msgHandler function provided.
|
||||
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
||||
func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
var backOff = defaultBackoff(c.ctx)
|
||||
|
||||
operation := func() error {
|
||||
|
||||
c.notifyStreamDisconnected()
|
||||
|
||||
log.Debugf("signal connection state %v", c.signalConn.GetState())
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
|
||||
// connect to Signal stream identifying ourselves with a public Wireguard key
|
||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||
stream, err := c.connect(c.key.PublicKey().String())
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.notifyStreamConnected()
|
||||
|
||||
log.Infof("connected to the Signal Service stream")
|
||||
|
||||
// start receiving messages from the Signal stream (from other peers through signal)
|
||||
err = c.receive(stream, msgHandler)
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||
backOff.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *GrpcClient) notifyStreamDisconnected() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.status = StreamDisconnected
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyStreamConnected() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.status = StreamConnected
|
||||
if c.connectedCh != nil {
|
||||
// there are goroutines waiting on this channel -> release them
|
||||
close(c.connectedCh)
|
||||
c.connectedCh = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
if c.connectedCh == nil {
|
||||
c.connectedCh = make(chan struct{})
|
||||
}
|
||||
return c.connectedCh
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||
c.stream = nil
|
||||
|
||||
// add key fingerprint to the request header to be identified on the server side
|
||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||
ctx := metadata.NewOutgoingContext(c.ctx, md)
|
||||
|
||||
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
|
||||
|
||||
c.stream = stream
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// blocks
|
||||
header, err := c.stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registered := header.Get(proto.HeaderRegistered)
|
||||
if len(registered) == 0 {
|
||||
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Ready indicates whether the client is okay and Ready to be used
|
||||
// for now it just checks whether gRPC connection to the service is in state Ready
|
||||
func (c *GrpcClient) Ready() bool {
|
||||
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
|
||||
}
|
||||
|
||||
// WaitStreamConnected waits until the client is connected to the Signal stream
|
||||
func (c *GrpcClient) WaitStreamConnected() {
|
||||
|
||||
if c.status == StreamConnected {
|
||||
return
|
||||
}
|
||||
|
||||
ch := c.getStreamStatusChan()
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
case <-ch:
|
||||
}
|
||||
}
|
||||
|
||||
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
|
||||
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
|
||||
// GrpcClient.connWg can be used to wait
|
||||
func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
if c.stream == nil {
|
||||
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call GrpcClient.Receive before sending messages")
|
||||
}
|
||||
|
||||
err := c.stream.Send(msg)
|
||||
if err != nil {
|
||||
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
||||
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := &proto.Body{}
|
||||
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.Message{
|
||||
Key: msg.Key,
|
||||
RemoteKey: msg.RemoteKey,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
||||
|
||||
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
Key: msg.GetKey(),
|
||||
RemoteKey: msg.GetRemoteKey(),
|
||||
Body: encryptedBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send sends a message to the remote Peer through the Signal Exchange.
|
||||
func (c *GrpcClient) Send(msg *proto.Message) error {
|
||||
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
|
||||
encryptedMessage, err := c.encryptMessage(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, err = c.realClient.Send(ctx, encryptedMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// receive receives messages from other peers coming through the Signal Exchange
|
||||
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
log.Warnf("stream canceled (usually indicates shutdown)")
|
||||
return err
|
||||
} else if s.Code() == codes.Unavailable {
|
||||
log.Warnf("Signal Service is unavailable")
|
||||
return err
|
||||
} else if err == io.EOF {
|
||||
log.Warnf("Signal Service stream closed by server")
|
||||
return err
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||
|
||||
decryptedMessage, err := c.decryptMessage(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
}
|
||||
|
||||
err = msgHandler(decryptedMessage)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
//todo send something??
|
||||
}
|
||||
}
|
||||
}
|
72
signal/client/mock.go
Normal file
72
signal/client/mock.go
Normal file
@ -0,0 +1,72 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||
)
|
||||
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
GetStatusFunc func() Status
|
||||
StreamConnectedFunc func() bool
|
||||
ReadyFunc func() bool
|
||||
WaitStreamConnectedFunc func()
|
||||
ReceiveFunc func(msgHandler func(msg *proto.Message) error) error
|
||||
SendToStreamFunc func(msg *proto.EncryptedMessage) error
|
||||
SendFunc func(msg *proto.Message) error
|
||||
}
|
||||
|
||||
func (sm *MockClient) Close() error {
|
||||
if sm.CloseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.CloseFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) GetStatus() Status {
|
||||
if sm.GetStatusFunc == nil {
|
||||
return ""
|
||||
}
|
||||
return sm.GetStatusFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) StreamConnected() bool {
|
||||
if sm.StreamConnectedFunc == nil {
|
||||
return false
|
||||
}
|
||||
return sm.StreamConnectedFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) Ready() bool {
|
||||
if sm.ReadyFunc == nil {
|
||||
return false
|
||||
}
|
||||
return sm.ReadyFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) WaitStreamConnected() {
|
||||
if sm.WaitStreamConnectedFunc == nil {
|
||||
return
|
||||
}
|
||||
sm.WaitStreamConnectedFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
||||
if sm.ReceiveFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.ReceiveFunc(msgHandler)
|
||||
}
|
||||
|
||||
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
if sm.SendToStreamFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.SendToStreamFunc(msg)
|
||||
}
|
||||
|
||||
func (sm *MockClient) Send(msg *proto.Message) error {
|
||||
if sm.SendFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.SendFunc(msg)
|
||||
}
|
16
util/common.go
Normal file
16
util/common.go
Normal file
@ -0,0 +1,16 @@
|
||||
package util
|
||||
|
||||
// SliceDiff returns the elements in slice `x` that are not in slice `y`
|
||||
func SliceDiff(x, y []string) []string {
|
||||
mapY := make(map[string]struct{}, len(y))
|
||||
for _, val := range y {
|
||||
mapY[val] = struct{}{}
|
||||
}
|
||||
var diff []string
|
||||
for _, val := range x {
|
||||
if _, found := mapY[val]; !found {
|
||||
diff = append(diff, val)
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
Loading…
Reference in New Issue
Block a user