mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-12 13:38:37 +02:00
Support new Management service protocol (NetworkMap) (#193)
* feature: support new management service protocol * chore: add more logging to track networkmap serial * refactor: organize peer update code in engine * chore: fix lint issues * refactor: extract Signal client interface * test: add signal client mock * refactor: introduce Management Service client interface * chore: place management and signal clients mocks to respective packages * test: add Serial test to the engine * fix: lint issues * test: unit tests for a networkMapUpdate * test: unit tests Sync update
This commit is contained in:
parent
9a3fba3fa3
commit
5db130a12e
@ -82,7 +82,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
|
||||||
func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) {
|
func loginPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
loginResp, err := client.Login(serverPublicKey)
|
loginResp, err := client.Login(serverPublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -101,7 +101,7 @@ func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string)
|
|||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
func registerPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) {
|
func registerPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if setupKey == "" {
|
if setupKey == "" {
|
||||||
|
@ -83,7 +83,7 @@ func createEngineConfig(key wgtypes.Key, config *internal.Config, peerConfig *mg
|
|||||||
}
|
}
|
||||||
|
|
||||||
// connectToSignal creates Signal Service client and established a connection
|
// connectToSignal creates Signal Service client and established a connection
|
||||||
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.Client, error) {
|
func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
|
||||||
var sigTLSEnabled bool
|
var sigTLSEnabled bool
|
||||||
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS {
|
||||||
sigTLSEnabled = true
|
sigTLSEnabled = true
|
||||||
@ -101,7 +101,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
// connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
|
||||||
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.Client, *mgmProto.LoginResponse, error) {
|
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) {
|
||||||
log.Debugf("connecting to management server %s", managementAddr)
|
log.Debugf("connecting to management server %s", managementAddr)
|
||||||
client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled)
|
client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
mgmProto "github.com/wiretrustee/wiretrustee/management/proto"
|
mgmProto "github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
signal "github.com/wiretrustee/wiretrustee/signal/client"
|
signal "github.com/wiretrustee/wiretrustee/signal/client"
|
||||||
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
|
"github.com/wiretrustee/wiretrustee/util"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
@ -44,9 +45,9 @@ type EngineConfig struct {
|
|||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
type Engine struct {
|
type Engine struct {
|
||||||
// signal is a Signal Service client
|
// signal is a Signal Service client
|
||||||
signal *signal.Client
|
signal signal.Client
|
||||||
// mgmClient is a Management Service client
|
// mgmClient is a Management Service client
|
||||||
mgmClient *mgm.Client
|
mgmClient mgm.Client
|
||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerConns map[string]*peer.Conn
|
||||||
|
|
||||||
@ -64,6 +65,9 @@ type Engine struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
wgInterface iface.WGIface
|
wgInterface iface.WGIface
|
||||||
|
|
||||||
|
// networkSerial is the latest Serial (state ID) of the network sent by the Management service
|
||||||
|
networkSerial uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@ -73,17 +77,18 @@ type Peer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine
|
// NewEngine creates a new Connection Engine
|
||||||
func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
|
func NewEngine(signalClient signal.Client, mgmClient mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
|
||||||
return &Engine{
|
return &Engine{
|
||||||
signal: signalClient,
|
signal: signalClient,
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
peerConns: map[string]*peer.Conn{},
|
peerConns: map[string]*peer.Conn{},
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
STUNs: []*ice.URL{},
|
STUNs: []*ice.URL{},
|
||||||
TURNs: []*ice.URL{},
|
TURNs: []*ice.URL{},
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
networkSerial: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,7 +96,7 @@ func (e *Engine) Stop() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
err := e.removeAllPeerConnections()
|
err := e.removeAllPeers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -146,8 +151,22 @@ func (e *Engine) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) removePeers(peers []string) error {
|
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service
|
||||||
for _, p := range peers {
|
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
|
|
||||||
|
currentPeers := make([]string, 0, len(e.peerConns))
|
||||||
|
for p := range e.peerConns {
|
||||||
|
currentPeers = append(currentPeers, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
newPeers := make([]string, 0, len(peersUpdate))
|
||||||
|
for _, p := range peersUpdate {
|
||||||
|
newPeers = append(newPeers, p.GetWgPubKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
toRemove := util.SliceDiff(currentPeers, newPeers)
|
||||||
|
|
||||||
|
for _, p := range toRemove {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -157,7 +176,7 @@ func (e *Engine) removePeers(peers []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) removeAllPeerConnections() error {
|
func (e *Engine) removeAllPeers() error {
|
||||||
log.Debugf("removing all peer connections")
|
log.Debugf("removing all peer connections")
|
||||||
for p := range e.peerConns {
|
for p := range e.peerConns {
|
||||||
err := e.removePeer(p)
|
err := e.removePeer(p)
|
||||||
@ -189,6 +208,16 @@ func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus {
|
|||||||
|
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
func (e *Engine) GetPeers() []string {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
peers := []string{}
|
||||||
|
for s := range e.peerConns {
|
||||||
|
peers = append(peers, s)
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
// GetConnectedPeers returns a connection Status or nil if peer connection wasn't found
|
||||||
func (e *Engine) GetConnectedPeers() []string {
|
func (e *Engine) GetConnectedPeers() []string {
|
||||||
@ -205,7 +234,7 @@ func (e *Engine) GetConnectedPeers() []string {
|
|||||||
return peers
|
return peers
|
||||||
}
|
}
|
||||||
|
|
||||||
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client) error {
|
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
|
||||||
err := s.Send(&sProto.Message{
|
err := s.Send(&sProto.Message{
|
||||||
Key: myKey.PublicKey().String(),
|
Key: myKey.PublicKey().String(),
|
||||||
RemoteKey: remoteKey.String(),
|
RemoteKey: remoteKey.String(),
|
||||||
@ -223,7 +252,7 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client, isAnswer bool) error {
|
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
|
||||||
|
|
||||||
var t sProto.Body_Type
|
var t sProto.Body_Type
|
||||||
if isAnswer {
|
if isAnswer {
|
||||||
@ -246,37 +275,42 @@ func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.K
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if update.GetWiretrusteeConfig() != nil {
|
||||||
|
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//todo update signal
|
||||||
|
}
|
||||||
|
|
||||||
|
if update.GetNetworkMap() != nil {
|
||||||
|
// only apply new changes and ignore old ones
|
||||||
|
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
|
err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error {
|
||||||
e.syncMsgMux.Lock()
|
return e.handleSync(update)
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
if update.GetWiretrusteeConfig() != nil {
|
|
||||||
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
//todo update signal
|
|
||||||
}
|
|
||||||
|
|
||||||
if update.GetRemotePeers() != nil || update.GetRemotePeersIsEmpty() {
|
|
||||||
// empty arrays are serialized by protobuf to null, but for our case empty array is a valid state.
|
|
||||||
err := e.updatePeers(update.GetRemotePeers())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
@ -327,27 +361,41 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(remotePeers))
|
|
||||||
remotePeerMap := make(map[string]struct{})
|
serial := networkMap.GetSerial()
|
||||||
for _, p := range remotePeers {
|
if e.networkSerial > serial {
|
||||||
remotePeerMap[p.GetWgPubKey()] = struct{}{}
|
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//remove peers that are no longer available for us
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
toRemove := []string{}
|
|
||||||
for p := range e.peerConns {
|
// cleanup request, most likely our peer has been deleted
|
||||||
if _, ok := remotePeerMap[p]; !ok {
|
if networkMap.GetRemotePeersIsEmpty() {
|
||||||
toRemove = append(toRemove, p)
|
err := e.removeAllPeers()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := e.removePeers(networkMap.GetRemotePeers())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = e.addNewPeers(networkMap.GetRemotePeers())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := e.removePeers(toRemove)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// add new peers
|
e.networkSerial = serial
|
||||||
for _, p := range remotePeers {
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addNewPeers finds and adds peers that were not know before but arrived from the Management service with the update
|
||||||
|
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
|
for _, p := range peersUpdate {
|
||||||
peerKey := p.GetWgPubKey()
|
peerKey := p.GetWgPubKey()
|
||||||
peerIPs := p.GetAllowedIps()
|
peerIPs := p.GetAllowedIps()
|
||||||
if _, ok := e.peerConns[peerKey]; !ok {
|
if _, ok := e.peerConns[peerKey]; !ok {
|
||||||
|
@ -37,6 +37,232 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||||
|
|
||||||
|
// test setup
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
|
||||||
|
WgIfaceName: "utun100",
|
||||||
|
WgAddr: "100.64.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
}, cancel, ctx)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
idx int
|
||||||
|
networkMap *mgmtProto.NetworkMap
|
||||||
|
expectedLen int
|
||||||
|
expectedPeers []string
|
||||||
|
expectedSerial uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.10/24"},
|
||||||
|
}
|
||||||
|
|
||||||
|
peer2 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.11/24"},
|
||||||
|
}
|
||||||
|
|
||||||
|
peer3 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.12/24"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1st case - new peer and network map has Serial grater than local => apply the update
|
||||||
|
case1 := testCase{
|
||||||
|
idx: 1,
|
||||||
|
networkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 1,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||||
|
peer1,
|
||||||
|
},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
expectedLen: 1,
|
||||||
|
expectedPeers: []string{peer1.GetWgPubKey()},
|
||||||
|
expectedSerial: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2nd case - one extra peer added and network map has Serial grater than local => apply the update
|
||||||
|
case2 := testCase{
|
||||||
|
idx: 2,
|
||||||
|
networkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 2,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||||
|
peer1, peer2,
|
||||||
|
},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
expectedLen: 2,
|
||||||
|
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
|
||||||
|
expectedSerial: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3rd case - an update with 3 peers and Serial lower than the current serial of the engine => ignore the update
|
||||||
|
case3 := testCase{
|
||||||
|
idx: 3,
|
||||||
|
networkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 0,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||||
|
peer1, peer2, peer3,
|
||||||
|
},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
expectedLen: 2,
|
||||||
|
expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()},
|
||||||
|
expectedSerial: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4th case - an update with 2 peers (1 new and 1 old) => apply the update removing old peer and adding a new one
|
||||||
|
case4 := testCase{
|
||||||
|
idx: 3,
|
||||||
|
networkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 4,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{
|
||||||
|
peer2, peer3,
|
||||||
|
},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
expectedLen: 2,
|
||||||
|
expectedPeers: []string{peer2.GetWgPubKey(), peer3.GetWgPubKey()},
|
||||||
|
expectedSerial: 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5th case - an update with all peers to be removed
|
||||||
|
case5 := testCase{
|
||||||
|
idx: 3,
|
||||||
|
networkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 5,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{},
|
||||||
|
RemotePeersIsEmpty: true,
|
||||||
|
},
|
||||||
|
expectedLen: 0,
|
||||||
|
expectedPeers: nil,
|
||||||
|
expectedSerial: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range []testCase{case1, case2, case3, case4, case5} {
|
||||||
|
err = engine.updateNetworkMap(c.networkMap)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(engine.peerConns) != c.expectedLen {
|
||||||
|
t.Errorf("case %d expecting Engine.peerConns to be of size %d, got %d", c.idx, c.expectedLen, len(engine.peerConns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if engine.networkSerial != c.expectedSerial {
|
||||||
|
t.Errorf("case %d expecting Engine.networkSerial to be equal to %d, actual %d", c.idx, c.expectedSerial, engine.networkSerial)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range c.expectedPeers {
|
||||||
|
if _, ok := engine.peerConns[p]; !ok {
|
||||||
|
t.Errorf("case %d expecting Engine.peerConns to contain peer %s", c.idx, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_Sync(t *testing.T) {
|
||||||
|
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// feed updates to Engine via mocked Management client
|
||||||
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
|
defer close(updates)
|
||||||
|
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
|
|
||||||
|
for msg := range updates {
|
||||||
|
err := msgHandler(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
|
||||||
|
WgIfaceName: "utun100",
|
||||||
|
WgAddr: "100.64.0.1/24",
|
||||||
|
WgPrivateKey: key,
|
||||||
|
WgPort: 33100,
|
||||||
|
}, cancel, ctx)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := engine.Stop()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = engine.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peer1 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.10/24"},
|
||||||
|
}
|
||||||
|
peer2 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.11/24"},
|
||||||
|
}
|
||||||
|
peer3 := &mgmtProto.RemotePeerConfig{
|
||||||
|
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
|
||||||
|
AllowedIps: []string{"100.64.0.12/24"},
|
||||||
|
}
|
||||||
|
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
|
||||||
|
updates <- &mgmtProto.SyncResponse{
|
||||||
|
NetworkMap: &mgmtProto.NetworkMap{
|
||||||
|
Serial: 10,
|
||||||
|
PeerConfig: nil,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||||
|
RemotePeersIsEmpty: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.After(time.Second * 2)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatalf("timeout while waiting for test to finish")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_MultiplePeers(t *testing.T) {
|
func TestEngine_MultiplePeers(t *testing.T) {
|
||||||
|
|
||||||
//log.SetLevel(log.DebugLevel)
|
//log.SetLevel(log.DebugLevel)
|
||||||
@ -58,23 +284,14 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
sport := 10010
|
sport := 10010
|
||||||
signalServer, err := startSignal(sport)
|
sigServer, err := startSignal(sport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer signalServer.Stop()
|
defer sigServer.Stop()
|
||||||
mport := 33081
|
mport := 33081
|
||||||
mgmtServer, err := startManagement(mport, &server.Config{
|
mgmtServer, err := startManagement(mport, dir)
|
||||||
Stuns: []*server.Host{},
|
|
||||||
TURNConfig: &server.TURNConfig{},
|
|
||||||
Signal: &server.Host{
|
|
||||||
Proto: "http",
|
|
||||||
URI: "localhost:10000",
|
|
||||||
},
|
|
||||||
Datadir: dir,
|
|
||||||
HttpConfig: nil,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@ -201,7 +418,18 @@ func startSignal(port int) (*grpc.Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startManagement(port int, config *server.Config) (*grpc.Server, error) {
|
func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||||
|
|
||||||
|
config := &server.Config{
|
||||||
|
Stuns: []*server.Host{},
|
||||||
|
TURNConfig: &server.TURNConfig{},
|
||||||
|
Signal: &server.Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: "localhost:10000",
|
||||||
|
},
|
||||||
|
Datadir: dataDir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
}
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -419,3 +419,7 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) GetKey() string {
|
||||||
|
return conn.config.Key
|
||||||
|
}
|
||||||
|
@ -1,253 +1,15 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/wiretrustee/wiretrustee/client/system"
|
|
||||||
"github.com/wiretrustee/wiretrustee/encryption"
|
|
||||||
"github.com/wiretrustee/wiretrustee/management/proto"
|
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
"io"
|
"io"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client interface {
|
||||||
key wgtypes.Key
|
io.Closer
|
||||||
realClient proto.ManagementServiceClient
|
Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
ctx context.Context
|
GetServerPublicKey() (*wgtypes.Key, error)
|
||||||
conn *grpc.ClientConn
|
Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error)
|
||||||
}
|
Login(serverKey wgtypes.Key) (*proto.LoginResponse, error)
|
||||||
|
|
||||||
// NewClient creates a new client to Management service
|
|
||||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*Client, error) {
|
|
||||||
|
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
|
||||||
|
|
||||||
if tlsEnabled {
|
|
||||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
conn, err := grpc.DialContext(
|
|
||||||
mgmCtx,
|
|
||||||
addr,
|
|
||||||
transportOption,
|
|
||||||
grpc.WithBlock(),
|
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
||||||
Time: 15 * time.Second,
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed creating connection to Management Service %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
realClient := proto.NewManagementServiceClient(conn)
|
|
||||||
|
|
||||||
return &Client{
|
|
||||||
key: ourPrivateKey,
|
|
||||||
realClient: realClient,
|
|
||||||
ctx: ctx,
|
|
||||||
conn: conn,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes connection to the Management Service
|
|
||||||
func (c *Client) Close() error {
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
//defaultBackoff is a basic backoff mechanism for general issues
|
|
||||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
|
||||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
|
||||||
InitialInterval: 800 * time.Millisecond,
|
|
||||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
|
||||||
Multiplier: backoff.DefaultMultiplier,
|
|
||||||
MaxInterval: 10 * time.Second,
|
|
||||||
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
|
||||||
Stop: backoff.Stop,
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ready indicates whether the client is okay and ready to be used
|
|
||||||
// for now it just checks whether gRPC connection to the service is ready
|
|
||||||
func (c *Client) ready() bool {
|
|
||||||
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
|
||||||
// Blocking request. The result will be sent via msgHandler callback function
|
|
||||||
func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
|
||||||
|
|
||||||
var backOff = defaultBackoff(c.ctx)
|
|
||||||
|
|
||||||
operation := func() error {
|
|
||||||
|
|
||||||
log.Debugf("management connection state %v", c.conn.GetState())
|
|
||||||
|
|
||||||
if !c.ready() {
|
|
||||||
return fmt.Errorf("no connection to management")
|
|
||||||
}
|
|
||||||
|
|
||||||
// todo we already have it since we did the Login, maybe cache it locally?
|
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed getting Management Service public key: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := c.connectToStream(*serverPubKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to open Management Service stream: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("connected to the Management Service stream")
|
|
||||||
|
|
||||||
// blocking until error
|
|
||||||
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
|
|
||||||
if err != nil {
|
|
||||||
backOff.Reset()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := backoff.Retry(operation, backOff)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
|
||||||
req := &proto.SyncRequest{}
|
|
||||||
|
|
||||||
myPrivateKey := c.key
|
|
||||||
myPublicKey := myPrivateKey.PublicKey()
|
|
||||||
|
|
||||||
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed encrypting message: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
|
||||||
return c.realClient.Sync(c.ctx, syncReq)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
|
||||||
for {
|
|
||||||
update, err := stream.Recv()
|
|
||||||
if err == io.EOF {
|
|
||||||
log.Errorf("Management stream has been closed by server: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("disconnected from Management Service sync stream: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("got an update message from Management Service")
|
|
||||||
decryptedResp := &proto.SyncResponse{}
|
|
||||||
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = msgHandler(decryptedResp)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
|
|
||||||
func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
|
|
||||||
if !c.ready() {
|
|
||||||
return nil, fmt.Errorf("no connection to management")
|
|
||||||
}
|
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
|
||||||
defer cancel()
|
|
||||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &serverKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
|
||||||
if !c.ready() {
|
|
||||||
return nil, fmt.Errorf("no connection to management")
|
|
||||||
}
|
|
||||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to encrypt message: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
|
||||||
defer cancel()
|
|
||||||
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
|
||||||
WgPubKey: c.key.PublicKey().String(),
|
|
||||||
Body: loginReq,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
loginResp := &proto.LoginResponse{}
|
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to decrypt registration message: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
|
||||||
// Takes care of encrypting and decrypting messages.
|
|
||||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
|
||||||
func (c *Client) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
|
|
||||||
gi := system.GetInfo()
|
|
||||||
meta := &proto.PeerSystemMeta{
|
|
||||||
Hostname: gi.Hostname,
|
|
||||||
GoOS: gi.GoOS,
|
|
||||||
OS: gi.OS,
|
|
||||||
Core: gi.OSVersion,
|
|
||||||
Platform: gi.Platform,
|
|
||||||
Kernel: gi.Kernel,
|
|
||||||
WiretrusteeVersion: "",
|
|
||||||
}
|
|
||||||
log.Debugf("detected system %v", meta)
|
|
||||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
|
||||||
func (c *Client) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
|
|
||||||
return c.login(serverKey, &proto.LoginRequest{})
|
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var tested *Client
|
var tested *GrpcClient
|
||||||
var serverAddr string
|
var serverAddr string
|
||||||
|
|
||||||
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
253
management/client/grpc.go
Normal file
253
management/client/grpc.go
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/wiretrustee/wiretrustee/client/system"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GrpcClient struct {
|
||||||
|
key wgtypes.Key
|
||||||
|
realClient proto.ManagementServiceClient
|
||||||
|
ctx context.Context
|
||||||
|
conn *grpc.ClientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new client to Management service
|
||||||
|
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||||
|
|
||||||
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
|
|
||||||
|
if tlsEnabled {
|
||||||
|
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := grpc.DialContext(
|
||||||
|
mgmCtx,
|
||||||
|
addr,
|
||||||
|
transportOption,
|
||||||
|
grpc.WithBlock(),
|
||||||
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
|
Time: 15 * time.Second,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed creating connection to Management Service %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
realClient := proto.NewManagementServiceClient(conn)
|
||||||
|
|
||||||
|
return &GrpcClient{
|
||||||
|
key: ourPrivateKey,
|
||||||
|
realClient: realClient,
|
||||||
|
ctx: ctx,
|
||||||
|
conn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes connection to the Management Service
|
||||||
|
func (c *GrpcClient) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
//defaultBackoff is a basic backoff mechanism for general issues
|
||||||
|
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||||
|
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 800 * time.Millisecond,
|
||||||
|
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||||
|
Multiplier: backoff.DefaultMultiplier,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ready indicates whether the client is okay and ready to be used
|
||||||
|
// for now it just checks whether gRPC connection to the service is ready
|
||||||
|
func (c *GrpcClient) ready() bool {
|
||||||
|
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||||
|
// Blocking request. The result will be sent via msgHandler callback function
|
||||||
|
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
|
|
||||||
|
var backOff = defaultBackoff(c.ctx)
|
||||||
|
|
||||||
|
operation := func() error {
|
||||||
|
|
||||||
|
log.Debugf("management connection state %v", c.conn.GetState())
|
||||||
|
|
||||||
|
if !c.ready() {
|
||||||
|
return fmt.Errorf("no connection to management")
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo we already have it since we did the Login, maybe cache it locally?
|
||||||
|
serverPubKey, err := c.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed getting Management Service public key: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := c.connectToStream(*serverPubKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to open Management Service stream: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("connected to the Management Service stream")
|
||||||
|
|
||||||
|
// blocking until error
|
||||||
|
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
|
||||||
|
if err != nil {
|
||||||
|
backOff.Reset()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := backoff.Retry(operation, backOff)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
|
||||||
|
req := &proto.SyncRequest{}
|
||||||
|
|
||||||
|
myPrivateKey := c.key
|
||||||
|
myPublicKey := myPrivateKey.PublicKey()
|
||||||
|
|
||||||
|
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed encrypting message: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
||||||
|
return c.realClient.Sync(c.ctx, syncReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
|
for {
|
||||||
|
update, err := stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
log.Errorf("Management stream has been closed by server: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disconnected from Management Service sync stream: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("got an update message from Management Service")
|
||||||
|
decryptedResp := &proto.SyncResponse{}
|
||||||
|
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = msgHandler(decryptedResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
|
||||||
|
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||||
|
if !c.ready() {
|
||||||
|
return nil, fmt.Errorf("no connection to management")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
||||||
|
defer cancel()
|
||||||
|
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serverKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||||
|
if !c.ready() {
|
||||||
|
return nil, fmt.Errorf("no connection to management")
|
||||||
|
}
|
||||||
|
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to encrypt message: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
|
||||||
|
defer cancel()
|
||||||
|
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
||||||
|
WgPubKey: c.key.PublicKey().String(),
|
||||||
|
Body: loginReq,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
loginResp := &proto.LoginResponse{}
|
||||||
|
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to decrypt registration message: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||||
|
// Takes care of encrypting and decrypting messages.
|
||||||
|
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||||
|
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
|
||||||
|
gi := system.GetInfo()
|
||||||
|
meta := &proto.PeerSystemMeta{
|
||||||
|
Hostname: gi.Hostname,
|
||||||
|
GoOS: gi.GoOS,
|
||||||
|
OS: gi.OS,
|
||||||
|
Core: gi.OSVersion,
|
||||||
|
Platform: gi.Platform,
|
||||||
|
Kernel: gi.Kernel,
|
||||||
|
WiretrusteeVersion: "",
|
||||||
|
}
|
||||||
|
log.Debugf("detected system %v", meta)
|
||||||
|
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||||
|
func (c *GrpcClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
|
||||||
|
return c.login(serverKey, &proto.LoginRequest{})
|
||||||
|
}
|
49
management/client/mock.go
Normal file
49
management/client/mock.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/wiretrustee/wiretrustee/management/proto"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockClient struct {
|
||||||
|
CloseFunc func() error
|
||||||
|
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
|
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||||
|
RegisterFunc func(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error)
|
||||||
|
LoginFunc func(serverKey wgtypes.Key) (*proto.LoginResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) Close() error {
|
||||||
|
if m.CloseFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.CloseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
|
if m.SyncFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.SyncFunc(msgHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||||
|
if m.GetServerPublicKeyFunc == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return m.GetServerPublicKeyFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) {
|
||||||
|
if m.RegisterFunc == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return m.RegisterFunc(serverKey, setupKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) {
|
||||||
|
if m.LoginFunc == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return m.LoginFunc(serverKey)
|
||||||
|
}
|
@ -1,26 +1,11 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/wiretrustee/wiretrustee/encryption"
|
|
||||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/connectivity"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
|
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
|
||||||
@ -31,317 +16,15 @@ type Status string
|
|||||||
const StreamConnected Status = "Connected"
|
const StreamConnected Status = "Connected"
|
||||||
const StreamDisconnected Status = "Disconnected"
|
const StreamDisconnected Status = "Disconnected"
|
||||||
|
|
||||||
// Client Wraps the Signal Exchange Service gRpc client
|
type Client interface {
|
||||||
type Client struct {
|
io.Closer
|
||||||
key wgtypes.Key
|
StreamConnected() bool
|
||||||
realClient proto.SignalExchangeClient
|
GetStatus() Status
|
||||||
signalConn *grpc.ClientConn
|
Receive(msgHandler func(msg *proto.Message) error) error
|
||||||
ctx context.Context
|
Ready() bool
|
||||||
stream proto.SignalExchange_ConnectStreamClient
|
WaitStreamConnected()
|
||||||
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
|
SendToStream(msg *proto.EncryptedMessage) error
|
||||||
connectedCh chan struct{}
|
Send(msg *proto.Message) error
|
||||||
mux sync.Mutex
|
|
||||||
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
|
|
||||||
status Status
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) StreamConnected() bool {
|
|
||||||
return c.status == StreamConnected
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) GetStatus() Status {
|
|
||||||
return c.status
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close Closes underlying connections to the Signal Exchange
|
|
||||||
func (c *Client) Close() error {
|
|
||||||
return c.signalConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new Signal client
|
|
||||||
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*Client, error) {
|
|
||||||
|
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
|
||||||
|
|
||||||
if tlsEnabled {
|
|
||||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
conn, err := grpc.DialContext(
|
|
||||||
sigCtx,
|
|
||||||
addr,
|
|
||||||
transportOption,
|
|
||||||
grpc.WithBlock(),
|
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
||||||
Time: 15 * time.Second,
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to connect to the signalling server %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{
|
|
||||||
realClient: proto.NewSignalExchangeClient(conn),
|
|
||||||
ctx: ctx,
|
|
||||||
signalConn: conn,
|
|
||||||
key: key,
|
|
||||||
mux: sync.Mutex{},
|
|
||||||
status: StreamDisconnected,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//defaultBackoff is a basic backoff mechanism for general issues
|
|
||||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
|
||||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
|
||||||
InitialInterval: 800 * time.Millisecond,
|
|
||||||
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
|
||||||
Multiplier: backoff.DefaultMultiplier,
|
|
||||||
MaxInterval: 10 * time.Second,
|
|
||||||
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
|
||||||
Stop: backoff.Stop,
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}, ctx)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
|
|
||||||
// The messages will be handled by msgHandler function provided.
|
|
||||||
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
|
||||||
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
|
||||||
func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error {
|
|
||||||
|
|
||||||
var backOff = defaultBackoff(c.ctx)
|
|
||||||
|
|
||||||
operation := func() error {
|
|
||||||
|
|
||||||
c.notifyStreamDisconnected()
|
|
||||||
|
|
||||||
log.Debugf("signal connection state %v", c.signalConn.GetState())
|
|
||||||
if !c.Ready() {
|
|
||||||
return fmt.Errorf("no connection to signal")
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect to Signal stream identifying ourselves with a public Wireguard key
|
|
||||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
|
||||||
stream, err := c.connect(c.key.PublicKey().String())
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.notifyStreamConnected()
|
|
||||||
|
|
||||||
log.Infof("connected to the Signal Service stream")
|
|
||||||
|
|
||||||
// start receiving messages from the Signal stream (from other peers through signal)
|
|
||||||
err = c.receive(stream, msgHandler)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
|
||||||
backOff.Reset()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := backoff.Retry(operation, backOff)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (c *Client) notifyStreamDisconnected() {
|
|
||||||
c.mux.Lock()
|
|
||||||
defer c.mux.Unlock()
|
|
||||||
c.status = StreamDisconnected
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) notifyStreamConnected() {
|
|
||||||
c.mux.Lock()
|
|
||||||
defer c.mux.Unlock()
|
|
||||||
c.status = StreamConnected
|
|
||||||
if c.connectedCh != nil {
|
|
||||||
// there are goroutines waiting on this channel -> release them
|
|
||||||
close(c.connectedCh)
|
|
||||||
c.connectedCh = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) getStreamStatusChan() <-chan struct{} {
|
|
||||||
c.mux.Lock()
|
|
||||||
defer c.mux.Unlock()
|
|
||||||
if c.connectedCh == nil {
|
|
||||||
c.connectedCh = make(chan struct{})
|
|
||||||
}
|
|
||||||
return c.connectedCh
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
|
||||||
c.stream = nil
|
|
||||||
|
|
||||||
// add key fingerprint to the request header to be identified on the server side
|
|
||||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
|
||||||
ctx := metadata.NewOutgoingContext(c.ctx, md)
|
|
||||||
|
|
||||||
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
|
|
||||||
|
|
||||||
c.stream = stream
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// blocks
|
|
||||||
header, err := c.stream.Header()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
registered := header.Get(proto.HeaderRegistered)
|
|
||||||
if len(registered) == 0 {
|
|
||||||
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
|
|
||||||
}
|
|
||||||
|
|
||||||
return stream, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ready indicates whether the client is okay and Ready to be used
|
|
||||||
// for now it just checks whether gRPC connection to the service is in state Ready
|
|
||||||
func (c *Client) Ready() bool {
|
|
||||||
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitStreamConnected waits until the client is connected to the Signal stream
|
|
||||||
func (c *Client) WaitStreamConnected() {
|
|
||||||
|
|
||||||
if c.status == StreamConnected {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := c.getStreamStatusChan()
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
case <-ch:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
|
|
||||||
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
|
|
||||||
// Client.connWg can be used to wait
|
|
||||||
func (c *Client) SendToStream(msg *proto.EncryptedMessage) error {
|
|
||||||
if !c.Ready() {
|
|
||||||
return fmt.Errorf("no connection to signal")
|
|
||||||
}
|
|
||||||
if c.stream == nil {
|
|
||||||
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.stream.Send(msg)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
|
||||||
func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
|
||||||
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body := &proto.Body{}
|
|
||||||
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.Message{
|
|
||||||
Key: msg.Key,
|
|
||||||
RemoteKey: msg.RemoteKey,
|
|
||||||
Body: body,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
|
||||||
func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
|
||||||
|
|
||||||
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.EncryptedMessage{
|
|
||||||
Key: msg.GetKey(),
|
|
||||||
RemoteKey: msg.GetRemoteKey(),
|
|
||||||
Body: encryptedBody,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send sends a message to the remote Peer through the Signal Exchange.
|
|
||||||
func (c *Client) Send(msg *proto.Message) error {
|
|
||||||
|
|
||||||
if !c.Ready() {
|
|
||||||
return fmt.Errorf("no connection to signal")
|
|
||||||
}
|
|
||||||
|
|
||||||
encryptedMessage, err := c.encryptMessage(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
_, err = c.realClient.Send(ctx, encryptedMessage)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// receive receives messages from other peers coming through the Signal Exchange
|
|
||||||
func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient,
|
|
||||||
msgHandler func(msg *proto.Message) error) error {
|
|
||||||
|
|
||||||
for {
|
|
||||||
msg, err := stream.Recv()
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
|
||||||
log.Warnf("stream canceled (usually indicates shutdown)")
|
|
||||||
return err
|
|
||||||
} else if s.Code() == codes.Unavailable {
|
|
||||||
log.Warnf("Signal Service is unavailable")
|
|
||||||
return err
|
|
||||||
} else if err == io.EOF {
|
|
||||||
log.Warnf("Signal Service stream closed by server")
|
|
||||||
return err
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
|
||||||
|
|
||||||
decryptedMessage, err := c.decryptMessage(msg)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
err = msgHandler(decryptedMessage)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
|
||||||
//todo send something??
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
|
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
|
||||||
@ -369,7 +52,7 @@ func MarshalCredential(myKey wgtypes.Key, remoteKey wgtypes.Key, credential *Cre
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Credential is an instance of a Client's Credential
|
// Credential is an instance of a GrpcClient's Credential
|
||||||
type Credential struct {
|
type Credential struct {
|
||||||
UFrag string
|
UFrag string
|
||||||
Pwd string
|
Pwd string
|
||||||
|
@ -17,7 +17,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Client", func() {
|
var _ = Describe("GrpcClient", func() {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
addr string
|
addr string
|
||||||
@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
|
|||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
func createSignalClient(addr string, key wgtypes.Key) *Client {
|
func createSignalClient(addr string, key wgtypes.Key) *GrpcClient {
|
||||||
var sigTLSEnabled = false
|
var sigTLSEnabled = false
|
||||||
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
|
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
336
signal/client/grpc.go
Normal file
336
signal/client/grpc.go
Normal file
@ -0,0 +1,336 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/wiretrustee/wiretrustee/encryption"
|
||||||
|
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GrpcClient Wraps the Signal Exchange Service gRpc client
|
||||||
|
type GrpcClient struct {
|
||||||
|
key wgtypes.Key
|
||||||
|
realClient proto.SignalExchangeClient
|
||||||
|
signalConn *grpc.ClientConn
|
||||||
|
ctx context.Context
|
||||||
|
stream proto.SignalExchange_ConnectStreamClient
|
||||||
|
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
|
||||||
|
connectedCh chan struct{}
|
||||||
|
mux sync.Mutex
|
||||||
|
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
|
||||||
|
status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) StreamConnected() bool {
|
||||||
|
return c.status == StreamConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) GetStatus() Status {
|
||||||
|
return c.status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close Closes underlying connections to the Signal Exchange
|
||||||
|
func (c *GrpcClient) Close() error {
|
||||||
|
return c.signalConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new Signal client
|
||||||
|
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||||
|
|
||||||
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
|
|
||||||
|
if tlsEnabled {
|
||||||
|
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := grpc.DialContext(
|
||||||
|
sigCtx,
|
||||||
|
addr,
|
||||||
|
transportOption,
|
||||||
|
grpc.WithBlock(),
|
||||||
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
|
Time: 15 * time.Second,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to connect to the signalling server %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GrpcClient{
|
||||||
|
realClient: proto.NewSignalExchangeClient(conn),
|
||||||
|
ctx: ctx,
|
||||||
|
signalConn: conn,
|
||||||
|
key: key,
|
||||||
|
mux: sync.Mutex{},
|
||||||
|
status: StreamDisconnected,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//defaultBackoff is a basic backoff mechanism for general issues
|
||||||
|
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||||
|
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 800 * time.Millisecond,
|
||||||
|
RandomizationFactor: backoff.DefaultRandomizationFactor,
|
||||||
|
Multiplier: backoff.DefaultMultiplier,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}, ctx)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
|
||||||
|
// The messages will be handled by msgHandler function provided.
|
||||||
|
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||||
|
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
||||||
|
func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
||||||
|
|
||||||
|
var backOff = defaultBackoff(c.ctx)
|
||||||
|
|
||||||
|
operation := func() error {
|
||||||
|
|
||||||
|
c.notifyStreamDisconnected()
|
||||||
|
|
||||||
|
log.Debugf("signal connection state %v", c.signalConn.GetState())
|
||||||
|
if !c.Ready() {
|
||||||
|
return fmt.Errorf("no connection to signal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect to Signal stream identifying ourselves with a public Wireguard key
|
||||||
|
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||||
|
stream, err := c.connect(c.key.PublicKey().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.notifyStreamConnected()
|
||||||
|
|
||||||
|
log.Infof("connected to the Signal Service stream")
|
||||||
|
|
||||||
|
// start receiving messages from the Signal stream (from other peers through signal)
|
||||||
|
err = c.receive(stream, msgHandler)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||||
|
backOff.Reset()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := backoff.Retry(operation, backOff)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *GrpcClient) notifyStreamDisconnected() {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
c.status = StreamDisconnected
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) notifyStreamConnected() {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
c.status = StreamConnected
|
||||||
|
if c.connectedCh != nil {
|
||||||
|
// there are goroutines waiting on this channel -> release them
|
||||||
|
close(c.connectedCh)
|
||||||
|
c.connectedCh = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
if c.connectedCh == nil {
|
||||||
|
c.connectedCh = make(chan struct{})
|
||||||
|
}
|
||||||
|
return c.connectedCh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||||
|
c.stream = nil
|
||||||
|
|
||||||
|
// add key fingerprint to the request header to be identified on the server side
|
||||||
|
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||||
|
ctx := metadata.NewOutgoingContext(c.ctx, md)
|
||||||
|
|
||||||
|
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
|
||||||
|
|
||||||
|
c.stream = stream
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// blocks
|
||||||
|
header, err := c.stream.Header()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
registered := header.Get(proto.HeaderRegistered)
|
||||||
|
if len(registered) == 0 {
|
||||||
|
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ready indicates whether the client is okay and Ready to be used
|
||||||
|
// for now it just checks whether gRPC connection to the service is in state Ready
|
||||||
|
func (c *GrpcClient) Ready() bool {
|
||||||
|
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitStreamConnected waits until the client is connected to the Signal stream
|
||||||
|
func (c *GrpcClient) WaitStreamConnected() {
|
||||||
|
|
||||||
|
if c.status == StreamConnected {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := c.getStreamStatusChan()
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
case <-ch:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
|
||||||
|
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
|
||||||
|
// GrpcClient.connWg can be used to wait
|
||||||
|
func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||||
|
if !c.Ready() {
|
||||||
|
return fmt.Errorf("no connection to signal")
|
||||||
|
}
|
||||||
|
if c.stream == nil {
|
||||||
|
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call GrpcClient.Receive before sending messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.stream.Send(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||||
|
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
||||||
|
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body := &proto.Body{}
|
||||||
|
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.Message{
|
||||||
|
Key: msg.Key,
|
||||||
|
RemoteKey: msg.RemoteKey,
|
||||||
|
Body: body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||||
|
func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
||||||
|
|
||||||
|
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.EncryptedMessage{
|
||||||
|
Key: msg.GetKey(),
|
||||||
|
RemoteKey: msg.GetRemoteKey(),
|
||||||
|
Body: encryptedBody,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the remote Peer through the Signal Exchange.
|
||||||
|
func (c *GrpcClient) Send(msg *proto.Message) error {
|
||||||
|
|
||||||
|
if !c.Ready() {
|
||||||
|
return fmt.Errorf("no connection to signal")
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedMessage, err := c.encryptMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err = c.realClient.Send(ctx, encryptedMessage)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// receive receives messages from other peers coming through the Signal Exchange
|
||||||
|
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||||
|
msgHandler func(msg *proto.Message) error) error {
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := stream.Recv()
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||||
|
log.Warnf("stream canceled (usually indicates shutdown)")
|
||||||
|
return err
|
||||||
|
} else if s.Code() == codes.Unavailable {
|
||||||
|
log.Warnf("Signal Service is unavailable")
|
||||||
|
return err
|
||||||
|
} else if err == io.EOF {
|
||||||
|
log.Warnf("Signal Service stream closed by server")
|
||||||
|
return err
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||||
|
|
||||||
|
decryptedMessage, err := c.decryptMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = msgHandler(decryptedMessage)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||||
|
//todo send something??
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
72
signal/client/mock.go
Normal file
72
signal/client/mock.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockClient struct {
|
||||||
|
CloseFunc func() error
|
||||||
|
GetStatusFunc func() Status
|
||||||
|
StreamConnectedFunc func() bool
|
||||||
|
ReadyFunc func() bool
|
||||||
|
WaitStreamConnectedFunc func()
|
||||||
|
ReceiveFunc func(msgHandler func(msg *proto.Message) error) error
|
||||||
|
SendToStreamFunc func(msg *proto.EncryptedMessage) error
|
||||||
|
SendFunc func(msg *proto.Message) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) Close() error {
|
||||||
|
if sm.CloseFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sm.CloseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) GetStatus() Status {
|
||||||
|
if sm.GetStatusFunc == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return sm.GetStatusFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) StreamConnected() bool {
|
||||||
|
if sm.StreamConnectedFunc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sm.StreamConnectedFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) Ready() bool {
|
||||||
|
if sm.ReadyFunc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return sm.ReadyFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) WaitStreamConnected() {
|
||||||
|
if sm.WaitStreamConnectedFunc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sm.WaitStreamConnectedFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
||||||
|
if sm.ReceiveFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sm.ReceiveFunc(msgHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||||
|
if sm.SendToStreamFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sm.SendToStreamFunc(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *MockClient) Send(msg *proto.Message) error {
|
||||||
|
if sm.SendFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sm.SendFunc(msg)
|
||||||
|
}
|
16
util/common.go
Normal file
16
util/common.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
// SliceDiff returns the elements in slice `x` that are not in slice `y`
|
||||||
|
func SliceDiff(x, y []string) []string {
|
||||||
|
mapY := make(map[string]struct{}, len(y))
|
||||||
|
for _, val := range y {
|
||||||
|
mapY[val] = struct{}{}
|
||||||
|
}
|
||||||
|
var diff []string
|
||||||
|
for _, val := range x {
|
||||||
|
if _, found := mapY[val]; !found {
|
||||||
|
diff = append(diff, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return diff
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user