fix: graceful shutdown (#134)

* fix: graceful shutdown

* fix: windows graceful shutdown
This commit is contained in:
Mikhail Bragin 2021-10-17 22:15:38 +02:00 committed by GitHub
parent fcea3c99d4
commit bef3b3392b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 79 additions and 43 deletions

View File

@ -31,7 +31,8 @@ var (
} }
// Execution control channel for stopCh signal // Execution control channel for stopCh signal
stopCh chan int stopCh chan int
cleanupCh chan struct{}
) )
// Execute executes the root command. // Execute executes the root command.
@ -41,6 +42,7 @@ func Execute() error {
func init() { func init() {
stopCh = make(chan int) stopCh = make(chan int)
cleanupCh = make(chan struct{})
defaultConfigPath = "/etc/wiretrustee/config.json" defaultConfigPath = "/etc/wiretrustee/config.json"
defaultLogFile = "/var/log/wiretrustee/client.log" defaultLogFile = "/var/log/wiretrustee/client.log"

View File

@ -5,6 +5,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/wiretrustee/wiretrustee/util" "github.com/wiretrustee/wiretrustee/util"
"time"
) )
func (p *program) Start(s service.Service) error { func (p *program) Start(s service.Service) error {
@ -15,12 +16,19 @@ func (p *program) Start(s service.Service) error {
if err != nil { if err != nil {
return return
} }
}() }()
return nil return nil
} }
func (p *program) Stop(s service.Service) error { func (p *program) Stop(s service.Service) error {
stopCh <- 1
select {
case <-cleanupCh:
case <-time.After(time.Second * 10):
log.Warnf("failed waiting for service cleanup, terminating")
}
log.Info("stopped Wiretrustee service") //nolint
return nil return nil
} }
@ -29,11 +37,15 @@ var (
Use: "run", Use: "run",
Short: "runs wiretrustee as service", Short: "runs wiretrustee as service",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
return return
} }
SetupCloseHandler()
prg := &program{ prg := &program{
cmd: cmd, cmd: cmd,
args: args, args: args,

View File

@ -9,7 +9,6 @@ import (
mgm "github.com/wiretrustee/wiretrustee/management/client" mgm "github.com/wiretrustee/wiretrustee/management/client"
mgmProto "github.com/wiretrustee/wiretrustee/management/proto" mgmProto "github.com/wiretrustee/wiretrustee/management/proto"
signal "github.com/wiretrustee/wiretrustee/signal/client" signal "github.com/wiretrustee/wiretrustee/signal/client"
"github.com/wiretrustee/wiretrustee/util"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -20,13 +19,8 @@ var (
Use: "up", Use: "up",
Short: "install, login and start wiretrustee client", Short: "install, login and start wiretrustee client",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := util.InitLog(logLevel, logFile)
if err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
err = loginCmd.RunE(cmd, args) err := loginCmd.RunE(cmd, args)
if err != nil { if err != nil {
return err return err
} }
@ -117,7 +111,7 @@ func connectToManagement(ctx context.Context, managementAddr string, ourPrivateK
} }
} }
log.Infof("peer logged in to Management Service %s", managementAddr) log.Debugf("peer logged in to Management Service %s", managementAddr)
return client, loginResp, nil return client, loginResp, nil
} }
@ -166,7 +160,7 @@ func runClient() error {
} }
// create start the Wiretrustee Engine that will connect to the Signal and Management streams and manage connections to remote peers. // create start the Wiretrustee Engine that will connect to the Signal and Management streams and manage connections to remote peers.
engine := internal.NewEngine(signalClient, mgmClient, engineConfig, cancel) engine := internal.NewEngine(signalClient, mgmClient, engineConfig, cancel, ctx)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Wiretrustee Connection Engine: %s", err) log.Errorf("error while starting Wiretrustee Connection Engine: %s", err)
@ -175,14 +169,11 @@ func runClient() error {
log.Print("Wiretrustee engine started, my IP is: ", peerConfig.Address) log.Print("Wiretrustee engine started, my IP is: ", peerConfig.Address)
SetupCloseHandler()
select { select {
case <-stopCh: case <-stopCh:
case <-ctx.Done(): case <-ctx.Done():
} }
log.Info("shutting down Wiretrustee client")
err = mgmClient.Close() err = mgmClient.Close()
if err != nil { if err != nil {
log.Errorf("failed closing Management Service client %v", err) log.Errorf("failed closing Management Service client %v", err)
@ -200,5 +191,8 @@ func runClient() error {
return err return err
} }
log.Info("stopped Wiretrustee client")
cleanupCh <- struct{}{}
return nil return nil
} }

View File

@ -166,7 +166,7 @@ func (conn *Connection) Open(timeout time.Duration) error {
select { select {
case remoteAuth := <-conn.remoteAuthChannel: case remoteAuth := <-conn.remoteAuthChannel:
log.Infof("got a connection confirmation from peer %s", conn.Config.RemoteWgKey.String()) log.Debugf("got a connection confirmation from peer %s", conn.Config.RemoteWgKey.String())
err = conn.agent.GatherCandidates() err = conn.agent.GatherCandidates()
if err != nil { if err != nil {
@ -186,8 +186,11 @@ func (conn *Connection) Open(timeout time.Duration) error {
if err != nil { if err != nil {
return err return err
} }
useProxy := useProxy(pair)
// in case the remote peer is in the local network or one of the peers has public static IP -> no need for a Wireguard proxy, direct communication is possible. // in case the remote peer is in the local network or one of the peers has public static IP -> no need for a Wireguard proxy, direct communication is possible.
if !useProxy(pair) { if !useProxy {
log.Debugf("it is possible to establish a direct connection (without proxy) to peer %s - my addr: %s, remote addr: %s", conn.Config.RemoteWgKey.String(), pair.Local, pair.Remote) log.Debugf("it is possible to establish a direct connection (without proxy) to peer %s - my addr: %s, remote addr: %s", conn.Config.RemoteWgKey.String(), pair.Local, pair.Remote)
err = conn.wgProxy.StartLocal(fmt.Sprintf("%s:%d", pair.Remote.Address(), iface.WgPort)) err = conn.wgProxy.StartLocal(fmt.Sprintf("%s:%d", pair.Remote.Address(), iface.WgPort))
if err != nil { if err != nil {
@ -195,19 +198,17 @@ func (conn *Connection) Open(timeout time.Duration) error {
} }
} else { } else {
log.Infof("establishing secure tunnel to peer %s via selected candidate pair %s", conn.Config.RemoteWgKey.String(), pair) log.Debugf("establishing secure tunnel to peer %s via selected candidate pair %s", conn.Config.RemoteWgKey.String(), pair)
err = conn.wgProxy.Start(remoteConn) err = conn.wgProxy.Start(remoteConn)
if err != nil { if err != nil {
return err return err
} }
} }
if pair.Remote.Type() == ice.CandidateTypeRelay || pair.Local.Type() == ice.CandidateTypeRelay { relayed := pair.Remote.Type() == ice.CandidateTypeRelay || pair.Local.Type() == ice.CandidateTypeRelay
log.Infof("using relay with peer %s", conn.Config.RemoteWgKey)
}
conn.Status = StatusConnected conn.Status = StatusConnected
log.Infof("opened connection to peer %s", conn.Config.RemoteWgKey.String()) log.Infof("opened connection to peer %s [localProxy=%v, relayed=%v]", conn.Config.RemoteWgKey.String(), useProxy, relayed)
case <-conn.closeCond.C: case <-conn.closeCond.C:
conn.Status = StatusDisconnected conn.Status = StatusDisconnected
return fmt.Errorf("connection to peer %s has been closed", conn.Config.RemoteWgKey.String()) return fmt.Errorf("connection to peer %s has been closed", conn.Config.RemoteWgKey.String())
@ -271,7 +272,7 @@ func (conn *Connection) Close() error {
var err error var err error
conn.closeCond.Do(func() { conn.closeCond.Do(func() {
log.Warnf("closing connection to peer %s", conn.Config.RemoteWgKey.String()) log.Debugf("closing connection to peer %s", conn.Config.RemoteWgKey.String())
if a := conn.agent; a != nil { if a := conn.agent; a != nil {
e := a.Close() e := a.Close()

View File

@ -57,6 +57,8 @@ type Engine struct {
TURNs []*ice.URL TURNs []*ice.URL
cancel context.CancelFunc cancel context.CancelFunc
ctx context.Context
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -66,7 +68,7 @@ type Peer struct {
} }
// NewEngine creates a new Connection Engine // NewEngine creates a new Connection Engine
func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *EngineConfig, cancel context.CancelFunc) *Engine { func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine {
return &Engine{ return &Engine{
signal: signalClient, signal: signalClient,
mgmClient: mgmClient, mgmClient: mgmClient,
@ -77,17 +79,25 @@ func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *Engin
STUNs: []*ice.URL{}, STUNs: []*ice.URL{},
TURNs: []*ice.URL{}, TURNs: []*ice.URL{},
cancel: cancel, cancel: cancel,
ctx: ctx,
} }
} }
func (e *Engine) Stop() error { func (e *Engine) Stop() error {
err := e.removeAllPeerConnections()
if err != nil {
return err
}
log.Debugf("removing Wiretrustee interface %s", e.config.WgIface) log.Debugf("removing Wiretrustee interface %s", e.config.WgIface)
err := iface.Close() err = iface.Close()
if err != nil { if err != nil {
log.Errorf("failed closing Wiretrustee interface %s %v", e.config.WgIface, err) log.Errorf("failed closing Wiretrustee interface %s %v", e.config.WgIface, err)
return err return err
} }
log.Infof("stopped Wiretrustee Engine")
return nil return nil
} }
@ -127,7 +137,7 @@ func (e *Engine) Start() error {
// initializePeer peer agent attempt to open connection // initializePeer peer agent attempt to open connection
func (e *Engine) initializePeer(peer Peer) { func (e *Engine) initializePeer(peer Peer) {
var backOff = &backoff.ExponentialBackOff{ var backOff = backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: backoff.DefaultInitialInterval, InitialInterval: backoff.DefaultInitialInterval,
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
@ -135,13 +145,14 @@ func (e *Engine) initializePeer(peer Peer) {
MaxElapsedTime: time.Duration(0), //never stop MaxElapsedTime: time.Duration(0), //never stop
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }, e.ctx)
operation := func() error { operation := func() error {
_, err := e.openPeerConnection(e.wgPort, e.config.WgPrivateKey, peer) _, err := e.openPeerConnection(e.wgPort, e.config.WgPrivateKey, peer)
e.peerMux.Lock() e.peerMux.Lock()
defer e.peerMux.Unlock() defer e.peerMux.Unlock()
if _, ok := e.conns[peer.WgPubKey]; !ok { if _, ok := e.conns[peer.WgPubKey]; !ok {
log.Infof("removing connection attempt with Peer: %v, not retrying", peer.WgPubKey) log.Debugf("removed connection attempt to peer: %v, not retrying", peer.WgPubKey)
return nil return nil
} }
@ -172,6 +183,19 @@ func (e *Engine) removePeerConnections(peers []string) error {
return nil return nil
} }
func (e *Engine) removeAllPeerConnections() error {
log.Debugf("removing all peer connections")
e.peerMux.Lock()
defer e.peerMux.Unlock()
for peer := range e.conns {
err := e.removePeerConnection(peer)
if err != nil {
return err
}
}
return nil
}
// removePeerConnection closes existing peer connection and removes peer // removePeerConnection closes existing peer connection and removes peer
func (e *Engine) removePeerConnection(peerKey string) error { func (e *Engine) removePeerConnection(peerKey string) error {
conn, exists := e.conns[peerKey] conn, exists := e.conns[peerKey]
@ -179,6 +203,7 @@ func (e *Engine) removePeerConnection(peerKey string) error {
delete(e.conns, peerKey) delete(e.conns, peerKey)
return conn.Close() return conn.Close()
} }
log.Infof("removed connection to peer %s", peerKey)
return nil return nil
} }
@ -310,7 +335,7 @@ func (e *Engine) receiveManagementEvents() {
e.cancel() e.cancel()
return return
} }
log.Infof("connected to Management Service updates stream") log.Debugf("stopped receiving updates from Management Service")
}() }()
log.Debugf("connecting to Management Service updates stream") log.Debugf("connecting to Management Service updates stream")
} }

View File

@ -87,7 +87,7 @@ func (p *WgProxy) proxyToRemotePeer(remoteConn *ice.Conn) {
for { for {
select { select {
case <-p.close: case <-p.close:
log.Infof("stopped proxying from remote peer %s due to closed connection", p.remoteKey) log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
return return
default: default:
n, err := p.wgConn.Read(buf) n, err := p.wgConn.Read(buf)
@ -113,7 +113,7 @@ func (p *WgProxy) proxyToLocalWireguard(remoteConn *ice.Conn) {
for { for {
select { select {
case <-p.close: case <-p.close:
log.Infof("stopped proxying from remote peer %s due to closed connection", p.remoteKey) log.Debugf("stopped proxying from remote peer %s due to closed connection", p.remoteKey)
return return
default: default:
n, err := remoteConn.Read(buf) n, err := remoteConn.Read(buf)

View File

@ -65,8 +65,8 @@ func (c *Client) Close() error {
} }
//defaultBackoff is a basic backoff mechanism for general issues //defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff() backoff.BackOff { func defaultBackoff(ctx context.Context) backoff.BackOff {
return &backoff.ExponentialBackOff{ return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
@ -74,14 +74,14 @@ func defaultBackoff() backoff.BackOff {
MaxElapsedTime: 24 * 3 * time.Hour, //stop after 3 days trying MaxElapsedTime: 24 * 3 * time.Hour, //stop after 3 days trying
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }, ctx)
} }
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function // Blocking request. The result will be sent via msgHandler callback function
func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
var backOff = defaultBackoff() var backOff = defaultBackoff(c.ctx)
operation := func() error { operation := func() error {
@ -114,7 +114,7 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
err := backoff.Retry(operation, backOff) err := backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Errorf("exiting Management Service connection retry loop due to unrecoverable error %s ", err) log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err)
return err return err
} }
@ -145,7 +145,7 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server
return err return err
} }
if err != nil { if err != nil {
log.Errorf("disconnected from Management Service sync stream: %v", err) log.Warnf("disconnected from Management Service sync stream: %v", err)
return err return err
} }
@ -159,7 +159,7 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server
err = msgHandler(decryptedResp) err = msgHandler(decryptedResp)
if err != nil { if err != nil {
log.Errorf("failed handling an update message received from Management Service %v", err.Error()) log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
return err return err
} }
} }

View File

@ -76,8 +76,8 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
} }
//defaultBackoff is a basic backoff mechanism for general issues //defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff() backoff.BackOff { func defaultBackoff(ctx context.Context) backoff.BackOff {
return &backoff.ExponentialBackOff{ return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
@ -85,7 +85,8 @@ func defaultBackoff() backoff.BackOff {
MaxElapsedTime: 24 * 3 * time.Hour, //stop after 3 days trying MaxElapsedTime: 24 * 3 * time.Hour, //stop after 3 days trying
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }, ctx)
} }
// Receive Connects to the Signal Exchange message stream and starts receiving messages. // Receive Connects to the Signal Exchange message stream and starts receiving messages.
@ -96,12 +97,13 @@ func (c *Client) Receive(msgHandler func(msg *proto.Message) error) {
c.connWg.Add(1) c.connWg.Add(1)
go func() { go func() {
var backOff = defaultBackoff() var backOff = defaultBackoff(c.ctx)
operation := func() error { operation := func() error {
err := c.connect(c.key.PublicKey().String(), msgHandler) err := c.connect(c.key.PublicKey().String(), msgHandler)
if err != nil { if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error %s. Retrying ... ", err) log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
c.connWg.Add(1) c.connWg.Add(1)
return err return err
} }
@ -112,7 +114,7 @@ func (c *Client) Receive(msgHandler func(msg *proto.Message) error) {
err := backoff.Retry(operation, backOff) err := backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Errorf("error while communicating with the Signal Exchange %s ", err) log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
return return
} }
}() }()