refactor signal client sync func (#147)

* refactor: move goroutine that runs Signal Client Receive to the engine for better control

* chore: fix comments typo

* test: fix golint

* chore: comments update

* chore: consider connection state=READY in signal and management clients

* chore: fix typos

* test: fix signal ping-pong test

* chore: add wait condition to signal client

* refactor: add stream status to the Signal client

* refactor: defer mutex unlock
This commit is contained in:
Mikhail Bragin 2021-11-06 15:00:13 +01:00 committed by GitHub
parent 4d34fb4e64
commit ed1e4dfc51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 236 additions and 124 deletions

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
ice "github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/iface" "github.com/wiretrustee/wiretrustee/iface"
mgm "github.com/wiretrustee/wiretrustee/management/client" mgm "github.com/wiretrustee/wiretrustee/management/client"
@ -142,7 +142,7 @@ func (e *Engine) initializePeer(peer Peer) {
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 5 * time.Second, MaxInterval: 5 * time.Second,
MaxElapsedTime: time.Duration(0), //never stop MaxElapsedTime: 0, //never stop
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, e.ctx) }, e.ctx)
@ -157,8 +157,7 @@ func (e *Engine) initializePeer(peer Peer) {
} }
if err != nil { if err != nil {
log.Warnln(err) log.Infof("retrying connection because of error: %s", err.Error())
log.Debugf("retrying connection because of error: %s", err.Error())
return err return err
} }
return nil return nil
@ -332,6 +331,8 @@ func (e *Engine) receiveManagementEvents() {
return nil return nil
}) })
if err != nil { if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
e.cancel() e.cancel()
return return
} }
@ -414,68 +415,77 @@ func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() { func (e *Engine) receiveSignalEvents() {
// connect to a stream of messages coming from the signal server
e.signal.Receive(func(msg *sProto.Message) error {
e.syncMsgMux.Lock() go func() {
defer e.syncMsgMux.Unlock() // connect to a stream of messages coming from the signal server
err := e.signal.Receive(func(msg *sProto.Message) error {
conn := e.conns[msg.Key] e.syncMsgMux.Lock()
if conn == nil { defer e.syncMsgMux.Unlock()
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
if conn.Config.RemoteWgKey.String() != msg.Key { conn := e.conns[msg.Key]
return fmt.Errorf("unknown peer %s", msg.Key) if conn == nil {
} return fmt.Errorf("wrongly addressed message %s", msg.Key)
switch msg.GetBody().Type {
case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
} }
err = conn.OnOffer(IceCredentials{
uFrag: remoteCred.UFrag,
pwd: remoteCred.Pwd,
})
if err != nil { if conn.Config.RemoteWgKey.String() != msg.Key {
return err return fmt.Errorf("unknown peer %s", msg.Key)
}
switch msg.GetBody().Type {
case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
err = conn.OnOffer(IceCredentials{
uFrag: remoteCred.UFrag,
pwd: remoteCred.Pwd,
})
if err != nil {
return err
}
return nil
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
err = conn.OnAnswer(IceCredentials{
uFrag: remoteCred.UFrag,
pwd: remoteCred.Pwd,
})
if err != nil {
return err
}
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err
}
err = conn.OnRemoteCandidate(candidate)
if err != nil {
log.Errorf("error handling CANDIATE from %s", msg.Key)
return err
}
} }
return nil return nil
case sProto.Body_ANSWER: })
remoteCred, err := signal.UnMarshalCredential(msg) if err != nil {
if err != nil { // happens if signal is unavailable for a long time.
return err // We want to cancel the operation of the whole client
} e.cancel()
err = conn.OnAnswer(IceCredentials{ return
uFrag: remoteCred.UFrag,
pwd: remoteCred.Pwd,
})
if err != nil {
return err
}
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err
}
err = conn.OnRemoteCandidate(candidate)
if err != nil {
log.Errorf("error handling CANDIATE from %s", msg.Key)
return err
}
} }
}()
return nil e.signal.WaitStreamConnected()
})
e.signal.WaitConnected()
} }

View File

@ -3,6 +3,7 @@ package client
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/client/system" "github.com/wiretrustee/wiretrustee/client/system"
@ -10,6 +11,7 @@ import (
"github.com/wiretrustee/wiretrustee/management/proto" "github.com/wiretrustee/wiretrustee/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"io" "io"
@ -71,12 +73,18 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second, MaxInterval: 10 * time.Second,
MaxElapsedTime: 30 * time.Minute, //stop after an 30 min of trying, the error will be propagated to the general retry of the client MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
} }
// ready reports whether the client can currently be used.
// At the moment the only criterion is that the underlying gRPC connection
// to the service is in the Ready state.
func (c *Client) ready() bool {
	state := c.conn.GetState()
	return state == connectivity.Ready
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function // Blocking request. The result will be sent via msgHandler callback function
func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
@ -85,6 +93,12 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
operation := func() error { operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
if !c.ready() {
return fmt.Errorf("no connection to management")
}
// todo we already have it since we did the Login, maybe cache it locally? // todo we already have it since we did the Login, maybe cache it locally?
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.GetServerPublicKey()
if err != nil { if err != nil {
@ -98,7 +112,7 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
return err return err
} }
log.Infof("connected to the Management Service Stream") log.Infof("connected to the Management Service stream")
// blocking until error // blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler) err = c.receiveEvents(stream, *serverPubKey, msgHandler)
@ -139,7 +153,7 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server
for { for {
update, err := stream.Recv() update, err := stream.Recv()
if err == io.EOF { if err == io.EOF {
log.Errorf("managment stream was closed: %s", err) log.Errorf("Management stream has been closed by server: %s", err)
return err return err
} }
if err != nil { if err != nil {
@ -165,6 +179,10 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server) // GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) { func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
defer cancel() defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
@ -181,6 +199,9 @@ func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
} }
func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil { if err != nil {
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)

View File

@ -11,6 +11,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -23,6 +24,12 @@ import (
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
// Status describes the state of the client's connection to the Signal stream.
type Status string

// Possible values of Status.
const (
	// streamConnected means the client currently holds an open Signal stream.
	streamConnected Status = "streamConnected"
	// streamDisconnected means the client has no open Signal stream.
	streamDisconnected Status = "streamDisconnected"
)
// Client Wraps the Signal Exchange Service gRpc client // Client Wraps the Signal Exchange Service gRpc client
type Client struct { type Client struct {
key wgtypes.Key key wgtypes.Key
@ -30,8 +37,11 @@ type Client struct {
signalConn *grpc.ClientConn signalConn *grpc.ClientConn
ctx context.Context ctx context.Context
stream proto.SignalExchange_ConnectStreamClient stream proto.SignalExchange_ConnectStreamClient
//waiting group to notify once stream is connected // connectedCh used to notify goroutines waiting for the connection to the Signal stream
connWg *sync.WaitGroup //todo use a channel instead?? connectedCh chan struct{}
mux sync.Mutex
// status indicates whether this client is currently connected to the Signal stream (streamConnected or streamDisconnected)
status Status
} }
// Close Closes underlying connections to the Signal Exchange // Close Closes underlying connections to the Signal Exchange
@ -65,13 +75,13 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
return nil, err return nil, err
} }
var wg sync.WaitGroup
return &Client{ return &Client{
realClient: proto.NewSignalExchangeClient(conn), realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx, ctx: ctx,
signalConn: conn, signalConn: conn,
key: key, key: key,
connWg: &wg, mux: sync.Mutex{},
status: streamDisconnected,
}, nil }, nil
} }
@ -82,7 +92,7 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second, MaxInterval: 10 * time.Second,
MaxElapsedTime: 30 * time.Minute, //stop after an 30 min of trying, the error will be propagated to the general retry of the client MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
@ -91,38 +101,76 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
// Receive Connects to the Signal Exchange message stream and starts receiving messages. // Receive Connects to the Signal Exchange message stream and starts receiving messages.
// The messages will be handled by msgHandler function provided. // The messages will be handled by msgHandler function provided.
// This function runs a goroutine underneath and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) // This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
// The key is the identifier of our Peer (could be Wireguard public key) // The connection retry logic will try to reconnect for 12 hours and, if unsuccessful, will propagate the error to the function caller.
func (c *Client) Receive(msgHandler func(msg *proto.Message) error) { func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error {
c.connWg.Add(1)
go func() {
var backOff = defaultBackoff(c.ctx) var backOff = defaultBackoff(c.ctx)
operation := func() error { operation := func() error {
stream, err := c.connect(c.key.PublicKey().String()) c.notifyStreamDisconnected()
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
c.connWg.Add(1)
return err
}
err = c.receive(stream, msgHandler) log.Debugf("signal connection state %v", c.signalConn.GetState())
if err != nil { if !c.ready() {
backOff.Reset() return fmt.Errorf("no connection to signal")
return err
}
return nil
} }
err := backoff.Retry(operation, backOff) // connect to Signal stream identifying ourselves with a public Wireguard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String())
if err != nil { if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err)
return return err
} }
}()
c.notifyStreamConnected()
log.Infof("streamConnected to the Signal Service stream")
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
if err != nil {
log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err)
backOff.Reset()
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
return err
}
return nil
}
// notifyStreamDisconnected records, under the client mutex, that the Signal stream is not connected.
func (c *Client) notifyStreamDisconnected() {
	c.mux.Lock()
	c.status = streamDisconnected
	c.mux.Unlock()
}
// notifyStreamConnected records, under the client mutex, that the Signal stream is connected
// and wakes up every goroutine blocked on the connection channel, if one has been created.
func (c *Client) notifyStreamConnected() {
	c.mux.Lock()
	defer c.mux.Unlock()

	c.status = streamConnected

	waiters := c.connectedCh
	if waiters == nil {
		// nobody asked for a notification channel -> nothing to release
		return
	}
	// closing the channel releases all waiting goroutines at once;
	// the field is reset so that the next wait gets a fresh channel
	close(waiters)
	c.connectedCh = nil
}
func (c *Client) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
} }
func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
@ -147,24 +195,37 @@ func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient,
if len(registered) == 0 { if len(registered) == 0 {
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
} }
//connection established we are good to use the stream
c.connWg.Done()
log.Infof("connected to the Signal Exchange Stream")
return stream, nil return stream, nil
} }
// WaitConnected waits until the client is connected to the message stream // ready indicates whether the client is okay and ready to be used
func (c *Client) WaitConnected() { // for now it just checks whether gRPC connection to the service is in state Ready
c.connWg.Wait() func (c *Client) ready() bool {
return c.signalConn.GetState() == connectivity.Ready
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *Client) WaitStreamConnected() {
if c.status == streamConnected {
return
}
ch := c.getStreamStatusChan()
select {
case <-c.ctx.Done():
case <-ch:
}
} }
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// Client.connWg can be used to wait // Client.WaitStreamConnected can be used to wait
func (c *Client) SendToStream(msg *proto.EncryptedMessage) error { func (c *Client) SendToStream(msg *proto.EncryptedMessage) error {
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
if c.stream == nil { if c.stream == nil {
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages") return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
} }
@ -221,13 +282,17 @@ func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, er
// Send sends a message to the remote Peer through the Signal Exchange. // Send sends a message to the remote Peer through the Signal Exchange.
func (c *Client) Send(msg *proto.Message) error { func (c *Client) Send(msg *proto.Message) error {
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := c.encryptMessage(msg) encryptedMessage, err := c.encryptMessage(msg)
if err != nil { if err != nil {
return err return err
} }
_, err = c.realClient.Send(context.TODO(), encryptedMessage) _, err = c.realClient.Send(context.TODO(), encryptedMessage)
if err != nil { if err != nil {
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) //log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
return err return err
} }
@ -244,10 +309,10 @@ func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient,
log.Warnf("stream canceled (usually indicates shutdown)") log.Warnf("stream canceled (usually indicates shutdown)")
return err return err
} else if s.Code() == codes.Unavailable { } else if s.Code() == codes.Unavailable {
log.Warnf("server has been stopped") log.Warnf("Signal Service is unavailable")
return err return err
} else if err == io.EOF { } else if err == io.EOF {
log.Warnf("stream closed by server") log.Warnf("Signal Service stream closed by server")
return err return err
} else if err != nil { } else if err != nil {
return err return err

View File

@ -36,7 +36,7 @@ var _ = Describe("Client", func() {
}) })
Describe("Exchanging messages", func() { Describe("Exchanging messages", func() {
Context("between connected peers", func() { Context("between streamConnected peers", func() {
It("should be successful", func() { It("should be successful", func() {
var msgReceived sync.WaitGroup var msgReceived sync.WaitGroup
@ -48,30 +48,42 @@ var _ = Describe("Client", func() {
// connect PeerA to Signal // connect PeerA to Signal
keyA, _ := wgtypes.GenerateKey() keyA, _ := wgtypes.GenerateKey()
clientA := createSignalClient(addr, keyA) clientA := createSignalClient(addr, keyA)
clientA.Receive(func(msg *sigProto.Message) error { go func() {
receivedOnA = msg.GetBody().GetPayload() err := clientA.Receive(func(msg *sigProto.Message) error {
msgReceived.Done() receivedOnA = msg.GetBody().GetPayload()
return nil msgReceived.Done()
}) return nil
clientA.WaitConnected() })
if err != nil {
return
}
}()
clientA.WaitStreamConnected()
// connect PeerB to Signal // connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey() keyB, _ := wgtypes.GenerateKey()
clientB := createSignalClient(addr, keyB) clientB := createSignalClient(addr, keyB)
clientB.Receive(func(msg *sigProto.Message) error {
receivedOnB = msg.GetBody().GetPayload() go func() {
err := clientB.Send(&sigProto.Message{ err := clientB.Receive(func(msg *sigProto.Message) error {
Key: keyB.PublicKey().String(), receivedOnB = msg.GetBody().GetPayload()
RemoteKey: keyA.PublicKey().String(), err := clientB.Send(&sigProto.Message{
Body: &sigProto.Body{Payload: "pong"}, Key: keyB.PublicKey().String(),
RemoteKey: keyA.PublicKey().String(),
Body: &sigProto.Body{Payload: "pong"},
})
if err != nil {
Fail("failed sending a message to PeerA")
}
msgReceived.Done()
return nil
}) })
if err != nil { if err != nil {
Fail("failed sending a message to PeerA") return
} }
msgReceived.Done() }()
return nil
}) clientB.WaitStreamConnected()
clientB.WaitConnected()
// PeerA initiates ping-pong // PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{ err := clientA.Send(&sigProto.Message{
@ -100,11 +112,15 @@ var _ = Describe("Client", func() {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
client := createSignalClient(addr, key) client := createSignalClient(addr, key)
client.Receive(func(msg *sigProto.Message) error { go func() {
return nil err := client.Receive(func(msg *sigProto.Message) error {
}) return nil
client.WaitConnected() })
if err != nil {
return
}
}()
client.WaitStreamConnected()
Expect(client).NotTo(BeNil()) Expect(client).NotTo(BeNil())
}) })
}) })