Add reconnect logic

This commit is contained in:
Zoltán Papp 2024-07-11 14:37:22 +02:00
parent 4e75e15ea1
commit ea93a5edd3

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -102,8 +103,9 @@ type Conn struct {
endpointRelay *net.UDPAddr endpointRelay *net.UDPAddr
iCEDisconnected chan struct{} // for reconnection operations
relayDisconnected chan struct{} iCEDisconnected chan bool
relayDisconnected chan bool
} }
// NewConn creates a new, not yet opened Conn to the remote peer. // NewConn creates a new, not yet opened Conn to the remote peer.
@ -130,8 +132,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
allowedIPsIP: allowedIPsIP.String(), allowedIPsIP: allowedIPsIP.String(),
statusRelay: StatusDisconnected, statusRelay: StatusDisconnected,
statusICE: StatusDisconnected, statusICE: StatusDisconnected,
iCEDisconnected: make(chan struct{}), iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan struct{}), relayDisconnected: make(chan bool, 1),
} }
rFns := WorkerRelayCallbacks{ rFns := WorkerRelayCallbacks{
@ -184,14 +186,18 @@ func (conn *Conn) Open() {
conn.log.Warnf("error while updating the state err: %v", err) conn.log.Warnf("error while updating the state err: %v", err)
} }
conn.waitRandomSleepTime() conn.waitInitialRandomSleepTime()
err = conn.doHandshake() err = conn.doHandshake()
if err != nil { if err != nil {
conn.log.Errorf("failed to send offer: %v", err) conn.log.Errorf("failed to send offer: %v", err)
} }
go conn.reconnectLoop() if conn.workerRelay.IsController() {
go conn.reconnectLoopWithRetry()
} else {
go conn.reconnectLoopForOnDisconnectedEvent()
}
} }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
@ -310,30 +316,75 @@ func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) reconnectLoop() { func (conn *Conn) reconnectLoopWithRetry() {
ticker := time.NewTicker(conn.config.Timeout) // wait for the initial connection to be established
if !conn.workerRelay.IsController() { select {
ticker.Stop() case <-conn.ctx.Done():
} else { case <-time.After(3 * time.Second):
defer ticker.Stop()
} }
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0,
Multiplier: 1.7,
MaxInterval: conn.config.Timeout * time.Second,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, conn.ctx)
ticker := backoff.NewTicker(bo)
defer ticker.Stop()
no := time.Now()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
// checks whether a peer connection is established via relay or ICE and has a WireGuard handshake; if so, skips the offer // checks whether a peer connection is established via relay or ICE and has a WireGuard handshake; if so, skips the offer
// todo check wg handshake // todo check wg handshake
conn.log.Tracef("ticker timedout, relay state: %s, ice state: %s, elapsed time: %s", conn.statusRelay, conn.statusICE, time.Since(no))
no = time.Now()
if conn.statusRelay == StatusConnected && conn.statusICE == StatusConnected { if conn.statusRelay == StatusConnected && conn.statusICE == StatusConnected {
continue continue
} }
case <-conn.relayDisconnected:
conn.log.Debugf("Relay connection is disconnected, start to send new offer") log.Debugf("ticker timed out, retry to do handshake")
ticker.Reset(10 * time.Second) err := conn.doHandshake()
conn.waitRandomSleepTime() if err != nil {
case <-conn.iCEDisconnected: conn.log.Errorf("failed to do handshake: %v", err)
conn.log.Debugf("ICE connection is disconnected, start to send new offer") }
ticker.Reset(10 * time.Second) case changed := <-conn.relayDisconnected:
conn.waitRandomSleepTime() if !changed {
continue
}
conn.log.Debugf("Relay state changed, reset reconnect timer")
bo.Reset()
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, reset reconnect timer")
bo.Reset()
case <-conn.ctx.Done():
return
}
}
}
func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
for {
select {
case changed := <-conn.relayDisconnected:
if !changed {
continue
}
conn.log.Debugf("Relay state changed, try to send new offer")
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, try to send new offer")
case <-conn.ctx.Done(): case <-conn.ctx.Done():
return return
} }
@ -438,10 +489,12 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState) conn.log.Tracef("ICE connection state changed to %s", newState)
defer func() { defer func() {
changed := conn.statusICE != newState && newState != StatusConnecting
conn.statusICE = newState conn.statusICE = newState
select { select {
case conn.iCEDisconnected <- struct{}{}: case conn.iCEDisconnected <- changed:
default: default:
} }
}() }()
@ -542,10 +595,11 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
defer func() { defer func() {
changed := conn.statusRelay != StatusDisconnected
conn.statusRelay = StatusDisconnected conn.statusRelay = StatusDisconnected
select { select {
case conn.relayDisconnected <- struct{}{}: case conn.relayDisconnected <- changed:
default: default:
} }
}() }()
@ -619,11 +673,12 @@ func (conn *Conn) doHandshake() error {
if err == nil { if err == nil {
ha.RelayAddr = addr ha.RelayAddr = addr
} }
conn.log.Tracef("send new offer: %#v", ha)
conn.log.Tracef("do handshake with args: %v", ha)
return conn.handshaker.SendOffer(ha) return conn.handshaker.SendOffer(ha)
} }
func (conn *Conn) waitRandomSleepTime() { func (conn *Conn) waitInitialRandomSleepTime() {
minWait := 500 minWait := 500
maxWait := 2000 maxWait := 2000
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond