Fix logic

Zoltán Papp 2024-06-25 15:13:08 +02:00
parent f72e852ccb
commit 0a67f5be1a
12 changed files with 345 additions and 291 deletions


@ -2,6 +2,7 @@ package peer
import ( import (
"context" "context"
"math/rand"
"net" "net"
"runtime" "runtime"
"strings" "strings"
@ -26,8 +27,8 @@ const (
 	defaultWgKeepAlive = 25 * time.Second
 	connPriorityRelay   ConnPriority = 1
-	connPriorityICETurn = 1
-	connPriorityICEP2P  = 2
+	connPriorityICETurn ConnPriority = 1
+	connPriorityICEP2P  ConnPriority = 2
 )
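The typed ConnPriority constants above are what the ready callbacks later in this file compare against conn.currentConnType before they overwrite the WireGuard endpoint: relay and ICE-over-TURN share the lowest priority, direct ICE wins. A minimal standalone sketch of that rule; the mayReplace helper is illustrative only and ignores the extra ICE-status check in relayConnectionIsReady:

package main

import "fmt"

type ConnPriority int

const (
    connPriorityRelay   ConnPriority = 1
    connPriorityICETurn ConnPriority = 1
    connPriorityICEP2P  ConnPriority = 2
)

// mayReplace mirrors the "conn.currentConnType > priority" guard used by the
// ready callbacks: a candidate transport may replace the current one only if
// its priority is at least as high.
func mayReplace(current, candidate ConnPriority) bool {
    return candidate >= current
}

func main() {
    fmt.Println(mayReplace(connPriorityICEP2P, connPriorityRelay)) // false: keep direct ICE
    fmt.Println(mayReplace(connPriorityRelay, connPriorityICEP2P)) // true: upgrade to direct ICE
}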
type WgConfig struct { type WgConfig struct {
@ -69,7 +70,6 @@ type WorkerCallbacks struct {
OnICEConnReadyCallback func(ConnPriority, ICEConnInfo) OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
OnICEStatusChanged func(ConnStatus) OnICEStatusChanged func(ConnStatus)
DoHandshake func(*OfferAnswer, error)
} }
type Conn struct { type Conn struct {
@ -83,6 +83,7 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
signaler *Signaler signaler *Signaler
relayManager *relayClient.Manager
allowedIPsIP string allowedIPsIP string
handshaker *Handshaker handshaker *Handshaker
@ -101,6 +102,9 @@ type Conn struct {
afterRemovePeerHooks []AfterRemovePeerHookFunc afterRemovePeerHooks []AfterRemovePeerHookFunc
endpointRelay *net.UDPAddr endpointRelay *net.UDPAddr
iCEDisconnected chan struct{}
relayDisconnected chan struct{}
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
@ -123,25 +127,29 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgProxyFactory: wgProxyFactory, wgProxyFactory: wgProxyFactory,
signaler: signaler, signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(), allowedIPsIP: allowedIPsIP.String(),
statusRelay: StatusDisconnected, statusRelay: StatusDisconnected,
statusICE: StatusDisconnected, statusICE: StatusDisconnected,
iCEDisconnected: make(chan struct{}),
relayDisconnected: make(chan struct{}),
} }
 	rFns := WorkerRelayCallbacks{
 		OnConnReady:     conn.relayConnectionIsReady,
-		OnStatusChanged: conn.onWorkerRelayStateChanged,
+		OnDisconnected:  conn.onWorkerRelayStateDisconnected,
 	}
 	wFns := WorkerICECallbacks{
 		OnConnReady:     conn.iCEConnectionIsReady,
-		OnStatusChanged: conn.onWorkerICEStateChanged,
-		DoHandshake:     conn.doHandshake,
+		OnStatusChanged: conn.onWorkerICEStateDisconnected,
 	}
 	conn.handshaker = NewHandshaker(ctx, connLog, config, signaler)
 	conn.workerRelay = NewWorkerRelay(ctx, connLog, config, relayManager, rFns)
-	conn.workerICE, err = NewWorkerICE(ctx, connLog, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, wFns)
+	relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
+	conn.workerICE, err = NewWorkerICE(ctx, connLog, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -172,14 +180,21 @@ func (conn *Conn) Open() {
 		conn.log.Warnf("error while updating the state err: %v", err)
 	}
-	relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
-	go conn.workerICE.SetupICEConnection(relayIsSupportedLocally)
+	conn.waitRandomSleepTime()
+	err = conn.doHandshake()
+	if err != nil {
+		conn.log.Errorf("failed to send offer: %v", err)
+	}
+	go conn.reconnectLoop()
 }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close() { func (conn *Conn) Close() {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.ctxCancel() conn.ctxCancel()
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@ -198,7 +213,6 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
// todo: is it problem if we try to remove a peer what is never existed?
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
if err != nil { if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
@ -268,7 +282,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)
} }
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay) conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) return conn.handshaker.OnRemoteOffer(offer)
} }
@ -288,136 +302,39 @@ func (conn *Conn) GetKey() string {
 	return conn.config.Key
 }
-func (conn *Conn) onWorkerICEStateChanged(newState ConnStatus) {
-	conn.mu.Lock()
-	defer conn.mu.Unlock()
-	log.Debugf("ICE connection state changed to %s", newState)
-	defer func() {
-		conn.statusICE = newState
-	}()
-	if conn.statusRelay == StatusConnected {
-		return
-	}
-	if conn.evalStatus() == newState {
-		return
-	}
-	if conn.endpointRelay != nil {
-		err := conn.configureWGEndpoint(conn.endpointRelay)
-		if err != nil {
-			conn.log.Errorf("failed to switch back to relay conn: %v", err)
-		}
-		// todo update status to relay related things
-		log.Debugf("switched back to relay connection")
-		return
-	}
-	if newState > conn.statusICE {
-		peerState := State{
-			PubKey:           conn.config.Key,
-			ConnStatus:       newState,
-			ConnStatusUpdate: time.Now(),
-			Mux:              new(sync.RWMutex),
-		}
-		_ = conn.statusRecorder.UpdatePeerState(peerState)
-	}
-}
-
-func (conn *Conn) onWorkerRelayStateChanged(newState ConnStatus) {
-	conn.mu.Lock()
-	defer conn.mu.Unlock()
-	defer func() {
-		conn.statusRelay = newState
-	}()
-	conn.log.Debugf("Relay connection state changed to %s", newState)
-	if conn.statusICE == StatusConnected {
-		return
-	}
-	if conn.evalStatus() == newState {
-		return
-	}
-	if newState > conn.statusRelay {
-		peerState := State{
-			PubKey:           conn.config.Key,
-			ConnStatus:       newState,
-			ConnStatusUpdate: time.Now(),
-			Mux:              new(sync.RWMutex),
-		}
-		_ = conn.statusRecorder.UpdatePeerState(peerState)
-	}
-}
-
-func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
-	conn.mu.Lock()
-	defer conn.mu.Unlock()
-	if conn.ctx.Err() != nil {
-		return
-	}
-	conn.log.Debugf("relay connection is ready")
-	conn.statusRelay = stateConnected
-	// todo review this condition
-	if conn.currentConnType > connPriorityRelay {
-		if conn.statusICE == StatusConnected {
-			log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnType)
-			return
-		}
-	}
-	if conn.currentConnType != 0 {
-		conn.log.Infof("update connection to Relay type")
-	}
-	wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
-	endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
-	if err != nil {
-		conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
-		return
-	}
-	endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
-	conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
-	conn.connID = nbnet.GenerateConnID()
-	for _, hook := range conn.beforeAddPeerHooks {
-		if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
-			conn.log.Errorf("Before add peer hook failed: %v", err)
-		}
-	}
-	err = conn.configureWGEndpoint(endpointUdpAddr)
-	if err != nil {
-		if err := wgProxy.CloseConn(); err != nil {
-			conn.log.Warnf("Failed to close relay connection: %v", err)
-		}
-		conn.log.Errorf("Failed to update wg peer configuration: %v", err)
-		return
-	}
-	conn.endpointRelay = endpointUdpAddr
-	if conn.wgProxyRelay != nil {
-		if err := conn.wgProxyRelay.CloseConn(); err != nil {
-			conn.log.Warnf("failed to close depracated wg proxy conn: %v", err)
-		}
-	}
-	conn.wgProxyRelay = wgProxy
-	conn.currentConnType = connPriorityRelay
-	peerState := State{
-		Direct:  false,
-		Relayed: true,
-	}
-	conn.updateStatus(peerState, rci.rosenpassPubKey, rci.rosenpassAddr)
-}
+func (conn *Conn) reconnectLoop() {
+	ticker := time.NewTicker(conn.config.Timeout) // todo use the interval from config
+	if !conn.workerRelay.IsController() {
+		ticker.Stop()
+	} else {
+		defer ticker.Stop()
+	}
+	for {
+		select {
+		case <-ticker.C:
+			// checks if there is peer connection is established via relay or ice and that it has a wireguard handshake and skip offer
+			// todo check wg handshake
+			if conn.statusRelay == StatusConnected && conn.statusICE == StatusConnected {
+				continue
+			}
+		case <-conn.relayDisconnected:
+			conn.log.Debugf("Relay connection is disconnected, start to send new offer")
+			ticker.Reset(10 * time.Second)
+			conn.waitRandomSleepTime()
+		case <-conn.iCEDisconnected:
+			conn.log.Debugf("ICE connection is disconnected, start to send new offer")
+			ticker.Reset(10 * time.Second)
+			conn.waitRandomSleepTime()
+		case <-conn.ctx.Done():
+			return
+		}
+		err := conn.doHandshake()
+		if err != nil {
+			conn.log.Errorf("failed to do handshake: %v", err)
+		}
+	}
+}
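Taken together, reconnectLoop centralizes the retry logic that previously lived in the ICE worker's own ticker loop: only the controller side (the peer with the lexicographically larger key) keeps the periodic ticker running, and both sides react to the relayDisconnected/iCEDisconnected signals by re-sending an offer after a short jitter. A compact standalone sketch of that pattern follows; the channel names and the sendOffer callback are stand-ins, and the jitter step is omitted for brevity.

package main

import (
    "context"
    "fmt"
    "time"
)

// reconnectPattern mirrors the shape of Conn.reconnectLoop: a ticker that only
// the controller keeps running, plus two disconnect triggers that force an
// earlier retry. sendOffer stands in for conn.doHandshake.
func reconnectPattern(ctx context.Context, isController bool, relayDown, iceDown <-chan struct{}, sendOffer func() error) {
    ticker := time.NewTicker(30 * time.Second) // assumed interval; the diff reuses config.Timeout
    if !isController {
        ticker.Stop() // the non-controller side only reacts to disconnect events
    } else {
        defer ticker.Stop()
    }

    for {
        select {
        case <-ticker.C:
        case <-relayDown:
            ticker.Reset(10 * time.Second)
        case <-iceDown:
            ticker.Reset(10 * time.Second)
        case <-ctx.Done():
            return
        }
        if err := sendOffer(); err != nil {
            fmt.Println("offer failed:", err)
        }
    }
}

func main() {
    relayDown := make(chan struct{}, 1)
    iceDown := make(chan struct{}, 1)
    relayDown <- struct{}{} // simulate a dropped relay connection

    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    reconnectPattern(ctx, true, relayDown, iceDown, func() error {
        fmt.Println("sending new offer")
        return nil
    })
}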
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
@ -431,9 +348,8 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready") conn.log.Debugf("ICE connection is ready")
conn.statusICE = stateConnected conn.statusICE = StatusConnected
// todo review this condition
if conn.currentConnType > priority { if conn.currentConnType > priority {
return return
} }
@ -503,6 +419,150 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.updateStatus(peerState, iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.updateStatus(peerState, iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
// todo review to make sense to handle connection and disconnected status also?
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.log.Tracef("ICE connection state changed to %s", newState)
defer func() {
conn.statusICE = newState
select {
case conn.iCEDisconnected <- struct{}{}:
default:
}
}()
// switch back to relay connection
if conn.endpointRelay != nil {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
err := conn.configureWGEndpoint(conn.endpointRelay)
if err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
// todo update status to relay related things
return
}
if conn.statusRelay == StatusConnected {
return
}
if conn.evalStatus() == newState {
return
}
if newState > conn.statusICE {
peerState := State{
PubKey: conn.config.Key,
ConnStatus: newState,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
_ = conn.statusRecorder.UpdatePeerState(peerState)
}
}
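The deferred, non-blocking send is what lets onWorkerICEStateDisconnected (and its relay twin below) signal reconnectLoop while still holding conn.mu: because the channel is unbuffered, the notification is delivered only if the loop is already parked in its select, otherwise it is dropped instead of blocking the callback. The pattern in isolation:

package main

import (
    "fmt"
    "time"
)

// notify performs the same best-effort signal used by the disconnect callbacks:
// it never blocks the caller and is simply dropped when nobody is receiving.
func notify(ch chan struct{}) bool {
    select {
    case ch <- struct{}{}:
        return true
    default:
        return false
    }
}

func main() {
    ch := make(chan struct{}) // unbuffered, as in the diff
    fmt.Println(notify(ch))   // false: no receiver is waiting, the signal is dropped

    go func() {
        <-ch
        fmt.Println("received disconnect signal")
    }()
    time.Sleep(10 * time.Millisecond) // let the receiver park on the channel
    fmt.Println(notify(ch))           // true: the receiver was waiting
    time.Sleep(10 * time.Millisecond)
}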
func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay = StatusConnected
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
if conn.currentConnType > connPriorityRelay {
if conn.statusICE == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnType)
return
}
}
conn.connID = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
}
err = conn.configureWGEndpoint(endpointUdpAddr)
if err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return
}
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close depracated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = wgProxy
conn.currentConnType = connPriorityRelay
peerState := State{
Direct: false,
Relayed: true,
}
conn.log.Infof("start to communicate with peer via relay")
conn.updateStatus(peerState, rci.rosenpassPubKey, rci.rosenpassAddr)
}
func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
defer func() {
conn.statusRelay = StatusDisconnected
select {
case conn.relayDisconnected <- struct{}{}:
default:
}
}()
if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil
}
if conn.statusICE == StatusConnected {
return
}
if conn.evalStatus() == StatusDisconnected {
return
}
if StatusDisconnected > conn.statusRelay {
peerState := State{
PubKey: conn.config.Key,
ConnStatus: StatusDisconnected,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
_ = conn.statusRecorder.UpdatePeerState(peerState)
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
return conn.config.WgConfig.WgInterface.UpdatePeer( return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey, conn.config.WgConfig.RemoteKey,
@ -531,7 +591,6 @@ func (conn *Conn) updateStatus(peerState State, remoteRosenpassPubKey []byte, re
if conn.onConnected != nil { if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr) conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr)
} }
return
} }
func (conn *Conn) doHandshake() error { func (conn *Conn) doHandshake() error {
@ -548,9 +607,24 @@ func (conn *Conn) doHandshake() error {
if err == nil { if err == nil {
ha.RelayAddr = addr.String() ha.RelayAddr = addr.String()
} }
conn.log.Tracef("send new offer: %#v", ha)
return conn.handshaker.SendOffer(ha) return conn.handshaker.SendOffer(ha)
} }
func (conn *Conn) waitRandomSleepTime() {
minWait := 500
maxWait := 2000
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-conn.ctx.Done():
case <-timeout.C:
}
}
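waitRandomSleepTime is plain jitter: each side delays its next offer by a random 500 to 2000 ms so that two peers reconnecting at the same moment do not keep colliding, and the select keeps the sleep cancellable through conn.ctx. The same idea as a tiny standalone helper (interval values copied from the diff):

package main

import (
    "context"
    "fmt"
    "math/rand"
    "time"
)

// sleepWithJitter blocks for a random duration in [min, max) unless ctx is
// cancelled first; it mirrors Conn.waitRandomSleepTime.
func sleepWithJitter(ctx context.Context, min, max time.Duration) {
    d := min + time.Duration(rand.Int63n(int64(max-min)))
    t := time.NewTimer(d)
    defer t.Stop()
    select {
    case <-ctx.Done():
    case <-t.C:
    }
}

func main() {
    start := time.Now()
    sleepWithJitter(context.Background(), 500*time.Millisecond, 2*time.Second)
    fmt.Println("slept for", time.Since(start).Round(time.Millisecond))
}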
func (conn *Conn) evalStatus() ConnStatus { func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected { if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
return StatusConnected return StatusConnected


@ -4,17 +4,12 @@ import (
"context" "context"
"errors" "errors"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
const (
handshakeCacheTimeout = 3 * time.Second
)
var ( var (
ErrSignalIsNotReady = errors.New("signal is not ready") ErrSignalIsNotReady = errors.New("signal is not ready")
) )
@ -65,9 +60,6 @@ type Handshaker struct {
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
remoteAnswerCh chan OfferAnswer remoteAnswerCh chan OfferAnswer
remoteOfferAnswer *OfferAnswer
remoteOfferAnswerCreated time.Time
lastOfferArgs HandshakeArgs lastOfferArgs HandshakeArgs
} }
@ -88,6 +80,7 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen() { func (h *Handshaker) Listen() {
for { for {
log.Debugf("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation() remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
if err != nil { if err != nil {
if _, ok := err.(*ConnectionClosedError); ok { if _, ok := err.(*ConnectionClosedError); ok {


@ -3,7 +3,6 @@ package peer
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
@ -68,7 +67,6 @@ type ICEConnInfo struct {
type WorkerICECallbacks struct { type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo) OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus) OnStatusChanged func(ConnStatus)
DoHandshake func() error
} }
type WorkerICE struct { type WorkerICE struct {
@ -79,6 +77,7 @@ type WorkerICE struct {
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status statusRecorder *Status
hasRelayOnLocally bool
conn WorkerICECallbacks conn WorkerICECallbacks
selectedPriority ConnPriority selectedPriority ConnPriority
@ -92,13 +91,9 @@ type WorkerICE struct {
localUfrag string localUfrag string
localPwd string localPwd string
creadantialHasUsed bool
hasRelayOnLocally bool
tickerCancel context.CancelFunc
ticker *time.Ticker
} }
-func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, configICE ICEConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, callBacks WorkerICECallbacks) (*WorkerICE, error) {
+func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, configICE ICEConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
w := &WorkerICE{ w := &WorkerICE{
ctx: ctx, ctx: ctx,
log: log, log: log,
@ -107,6 +102,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, config
signaler: signaler, signaler: signaler,
iFaceDiscover: ifaceDiscover, iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks, conn: callBacks,
} }
@ -119,62 +115,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, config
return w, nil return w, nil
} }
func (w *WorkerICE) SetupICEConnection(hasRelayOnLocally bool) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent != nil {
return
}
w.hasRelayOnLocally = hasRelayOnLocally
go w.sendOffer()
}
func (w *WorkerICE) sendOffer() {
w.ticker = time.NewTicker(w.config.Timeout)
defer w.ticker.Stop()
tickerCtx, tickerCancel := context.WithCancel(w.ctx)
w.tickerCancel = tickerCancel
w.conn.OnStatusChanged(StatusConnecting)
w.log.Debugf("ICE trigger a new handshake")
err := w.conn.DoHandshake()
if err != nil {
w.log.Errorf("%s", err)
}
for {
w.log.Debugf("ICE trigger new reconnect handshake")
select {
case <-w.ticker.C:
err := w.conn.DoHandshake()
if err != nil {
w.log.Errorf("%s", err)
}
case <-tickerCtx.Done():
w.log.Debugf("left reconnect loop")
return
case <-w.ctx.Done():
return
}
}
}
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
log.Debugf("OnNewOffer for ICE") w.log.Debugf("OnNewOffer for ICE")
w.muxAgent.Lock() w.muxAgent.Lock()
if w.agent != nil { if w.agent != nil {
log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
// cancel reconnection loop
w.log.Debugf("canceling reconnection loop")
w.tickerCancel()
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P w.selectedPriority = connPriorityICEP2P
@ -184,7 +134,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = candidateTypes() preferredCandidateTypes = candidateTypes()
} }
w.log.Debugf("recreate agent") w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx) agentCtx, agentCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
if err != nil { if err != nil {
@ -204,14 +154,14 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
// will block until connection succeeded // will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state, // but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection // so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turnAgentDial") w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer) remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
if err != nil { if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err) w.log.Debugf("failed to dial the remote peer: %s", err)
return return
} }
w.log.Debugf("agent dial succeeded")
w.log.Debugf("GetSelectedCandidatePair")
pair, err := w.agent.GetSelectedCandidatePair() pair, err := w.agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
return return
@ -240,7 +190,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
Relayed: isRelayed(pair), Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on conn ready") w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci) go w.conn.OnConnReady(w.selectedPriority, ci)
} }
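The new hasRelayOnLocally flag feeds the branch near the top of OnNewOffer: when this peer has a relay connection and the remote offer carries a relay address, the worker selects connPriorityICEP2P and presumably restricts gathering to direct candidates, since the NetBird relay already covers the hard cases; otherwise it keeps the full candidate set. The helpers' bodies are not part of this diff, so the sketch below only shows their plausible shape, with plain strings instead of pion/ice candidate types:

package main

import "fmt"

// Assumed shapes of the two helpers referenced in OnNewOffer; the real ones
// return ICE candidate types, plain strings keep this sketch dependency-free.
func candidateTypesP2P() []string { return []string{"host", "srflx"} }
func candidateTypes() []string    { return []string{"host", "srflx", "relay"} }

func main() {
    hasRelayOnLocally := true
    remoteRelaySrvAddress := "rel.example.com:443" // stands in for remoteOfferAnswer.RelaySrvAddress

    if hasRelayOnLocally && remoteRelaySrvAddress != "" {
        // both sides can fall back to the NetBird relay: try direct candidates only
        fmt.Println("P2P-only candidates:", candidateTypesP2P())
    } else {
        // no common relay: also gather TURN relay candidates
        fmt.Println("all candidate types:", candidateTypes())
    }
}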
@ -325,7 +275,6 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport [
w.agent = nil w.agent = nil
w.muxAgent.Unlock() w.muxAgent.Unlock()
go w.sendOffer()
} }
}) })
if err != nil { if err != nil {
@ -430,23 +379,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
} }
} }
// waitForReconnectTry waits for a random duration before trying to reconnect
func (w *WorkerICE) waitForReconnectTry() bool {
minWait := 500
maxWait := 2000
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-w.ctx.Done():
return false
case <-timeout.C:
return true
}
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{


@ -2,6 +2,7 @@ package peer
import ( import (
"context" "context"
"errors"
"net" "net"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -17,7 +18,7 @@ type RelayConnInfo struct {
type WorkerRelayCallbacks struct { type WorkerRelayCallbacks struct {
OnConnReady func(RelayConnInfo) OnConnReady func(RelayConnInfo)
OnStatusChanged func(ConnStatus) OnDisconnected func()
} }
type WorkerRelay struct { type WorkerRelay struct {
@ -41,9 +42,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, rela
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if !w.isRelaySupported(remoteOfferAnswer) { if !w.isRelaySupported(remoteOfferAnswer) {
w.log.Infof("Relay is not supported by remote peer") w.log.Infof("Relay is not supported by remote peer")
// todo should we retry?
// if the remote peer doesn't support relay make no sense to retry infinity
// but if the remote peer supports relay just the connection is lost we should retry
return return
} }
@ -55,12 +53,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
srv := w.preferredRelayServer(currentRelayAddress.String(), remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress.String(), remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.conn.OnDisconnected)
if err != nil { if err != nil {
// todo handle all error types
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Infof("do not need to reopen relay connection")
return
}
w.log.Infof("do not need to reopen relay connection: %s", err) w.log.Infof("do not need to reopen relay connection: %s", err)
return return
} }
w.log.Debugf("Relay connection established with %s", srv)
go w.conn.OnConnReady(RelayConnInfo{ go w.conn.OnConnReady(RelayConnInfo{
relayedConn: relayedConn, relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
@ -72,6 +77,14 @@ func (w *WorkerRelay) RelayAddress() (net.Addr, error) {
return w.relayManager.RelayAddress() return w.relayManager.RelayAddress()
} }
func (w *WorkerRelay) IsController() bool {
return w.config.LocalKey > w.config.Key
}
func (w *WorkerRelay) RelayIsSupportedLocally() bool {
return w.relayManager.HasRelayAddress()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() { if !w.relayManager.HasRelayAddress() {
return false return false
@ -80,12 +93,8 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
} }
func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
if w.config.LocalKey > w.config.Key { if w.IsController() {
return myRelayAddress return myRelayAddress
} }
return remoteRelayAddress return remoteRelayAddress
} }
func (w *WorkerRelay) RelayIsSupportedLocally() bool {
return w.relayManager.HasRelayAddress()
}
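IsController is a deterministic tie-breaker: both peers compare the same pair of keys, so exactly one of them becomes the controller that keeps the reconnect ticker running and whose relay server is preferred. A small standalone check of that property (the key strings are placeholders):

package main

import "fmt"

// isController mirrors WorkerRelay.IsController: the peer whose key is
// lexicographically larger takes the controller role.
func isController(localKey, remoteKey string) bool {
    return localKey > remoteKey
}

func main() {
    alice := "AAAA-alice-public-key" // placeholder keys
    bob := "BBBB-bob-public-key"

    // Exactly one side is the controller, and both sides agree on which one.
    fmt.Println("alice is controller:", isController(alice, bob)) // false
    fmt.Println("bob is controller:", isController(bob, alice))   // true
}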


@ -73,6 +73,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
default: default:
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
if err != nil { if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue continue
} }
@ -80,6 +81,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
p.cancel() p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
} }
continue continue
} }
@ -103,11 +106,13 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
p.cancel() p.cancel()
return return
} }
log.Errorf("failed to read from remote conn: %s", err)
continue continue
} }
_, err = p.localConn.Write(buf[:n]) _, err = p.localConn.Write(buf[:n])
if err != nil { if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue continue
} }
} }
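The wgproxy changes only add visibility: transient read and write failures are now logged and the loop keeps pumping, while io.EOF still cancels the proxy. The same policy in a generic standalone copy loop; net.Pipe and the pump helper below are illustrative, not the proxy's actual types:

package main

import (
    "context"
    "errors"
    "io"
    "log"
    "net"
)

// pump copies packets from src to dst, logging transient errors and stopping
// only on EOF or context cancellation, like proxyToRemote/proxyToLocal.
func pump(ctx context.Context, cancel context.CancelFunc, src, dst net.Conn) {
    buf := make([]byte, 1500)
    for ctx.Err() == nil {
        n, err := src.Read(buf)
        if err != nil {
            if errors.Is(err, io.EOF) {
                cancel() // the peer closed: stop both directions
                return
            }
            log.Printf("failed to read from conn: %s", err)
            continue
        }
        if _, err := dst.Write(buf[:n]); err != nil {
            log.Printf("failed to write to conn: %s", err)
        }
    }
}

func main() {
    client, server := net.Pipe()
    ctx, cancel := context.WithCancel(context.Background())
    go pump(ctx, cancel, server, server) // echo every packet back to the sender

    go func() { _, _ = client.Write([]byte("probe")) }()
    buf := make([]byte, 1500)
    n, _ := client.Read(buf)
    log.Printf("echoed %q", buf[:n])
    cancel()
}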


@ -115,7 +115,7 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect() error {
log.Infof("connecting to relay server: %s", c.serverAddress) c.log.Infof("connecting to relay server: %s", c.serverAddress)
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()


@ -21,7 +21,6 @@ func NewGuard(context context.Context, relayClient *Client) *Guard {
ctx: context, ctx: context,
relayClient: relayClient, relayClient: relayClient,
} }
return g return g
} }


@ -44,6 +44,9 @@ type Manager struct {
relayClients map[string]*RelayTrack relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]map[*func()]struct{}
listenerLock sync.Mutex
} }
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
@ -52,11 +55,11 @@ func NewManager(ctx context.Context, serverAddress string, peerID string) *Manag
srvAddress: serverAddress, srvAddress: serverAddress,
peerID: peerID, peerID: peerID,
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
} }
} }
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop. // Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop.
// todo: consider to return an error if the initial connection to the relay server is not established.
func (m *Manager) Serve() error { func (m *Manager) Serve() error {
if m.relayClient != nil { if m.relayClient != nil {
return fmt.Errorf("manager already serving") return fmt.Errorf("manager already serving")
@ -70,8 +73,9 @@ func (m *Manager) Serve() error {
} }
m.reconnectGuard = NewGuard(m.ctx, m.relayClient) m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
-	m.relayClient.SetOnDisconnectListener(m.reconnectGuard.OnDisconnected)
+	m.relayClient.SetOnDisconnectListener(func() {
+		m.onServerDisconnected(m.srvAddress)
+	})
m.startCleanupLoop() m.startCleanupLoop()
return nil return nil
@ -80,7 +84,7 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. // connection to the relay server.
-func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
+func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) {
if m.relayClient == nil { if m.relayClient == nil {
return nil, errRelayClientNotConnected return nil, errRelayClientNotConnected
} }
@ -90,13 +94,26 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
return nil, err return nil, err
} }
var (
netConn net.Conn
)
 	if !foreign {
 		log.Debugf("open peer connection via permanent server: %s", peerKey)
-		return m.relayClient.OpenConn(peerKey)
+		netConn, err = m.relayClient.OpenConn(peerKey)
 	} else {
 		log.Debugf("open peer connection via foreign server: %s", serverAddress)
-		return m.openConnVia(serverAddress, peerKey)
+		netConn, err = m.openConnVia(serverAddress, peerKey)
 	}
if err != nil {
return nil, err
}
if onClosedListener != nil {
m.addListener(serverAddress, onClosedListener)
}
return netConn, err
} }
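OpenConn now accepts an optional onClosedListener that the manager registers per relay server and fires from onServerDisconnected; the peer layer passes WorkerRelay's OnDisconnected callback here. A hedged usage sketch against the new signature; the import path, alias, and addresses below are assumptions, only the Manager methods shown in this diff are taken as given:

package main

import (
    "context"
    "log"

    // assumed import path for the relay client package; the alias matches conn.go
    relayClient "github.com/netbirdio/netbird/relay/client"
)

func main() {
    mgr := relayClient.NewManager(context.Background(), "relay.example.com:443", "my-peer-id")
    if err := mgr.Serve(); err != nil {
        log.Fatalf("relay manager: %v", err)
    }

    conn, err := mgr.OpenConn("relay.example.com:443", "remote-peer-hashed-id", func() {
        // Fired when the relay server behind this peer connection is lost;
        // the peer layer uses this to trigger a new offer.
        log.Println("relay connection lost")
    })
    if err != nil {
        log.Fatalf("open conn: %v", err)
    }
    defer conn.Close()
}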
// RelayAddress returns the address of the permanent relay server. It could change if the network connection is lost. // RelayAddress returns the address of the permanent relay server. It could change if the network connection is lost.
@ -152,7 +169,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
} }
// if connection closed then delete the relay client from the list // if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() { relayClient.SetOnDisconnectListener(func() {
m.deleteRelayConn(serverAddress) m.onServerDisconnected(serverAddress)
}) })
rt.relayClient = relayClient rt.relayClient = relayClient
rt.Unlock() rt.Unlock()
@ -164,11 +181,12 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
return conn, nil return conn, nil
} }
-func (m *Manager) deleteRelayConn(address string) {
-	log.Infof("deleting relay client for %s", address)
-	m.relayClientsMutex.Lock()
-	delete(m.relayClients, address)
-	m.relayClientsMutex.Unlock()
-}
+func (m *Manager) onServerDisconnected(serverAddress string) {
+	if serverAddress == m.srvAddress {
+		m.reconnectGuard.OnDisconnected()
+	}
+	m.notifyOnDisconnectListeners(serverAddress)
+}
func (m *Manager) isForeignServer(address string) (bool, error) { func (m *Manager) isForeignServer(address string) (bool, error) {
@ -212,8 +230,33 @@ func (m *Manager) cleanUpUnusedRelays() {
go func() { go func() {
_ = rt.relayClient.Close() _ = rt.relayClient.Close()
}() }()
log.Debugf("clean up relay client: %s", addr) log.Debugf("clean up unused relay server connection: %s", addr)
delete(m.relayClients, addr) delete(m.relayClients, addr)
rt.Unlock() rt.Unlock()
} }
} }
func (m *Manager) addListener(serverAddress string, onClosedListener func()) {
m.listenerLock.Lock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
l = make(map[*func()]struct{})
}
l[&onClosedListener] = struct{}{}
m.onDisconnectedListeners[serverAddress] = l
m.listenerLock.Unlock()
}
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock() // unlock on every path, including the early return below
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
return
}
for f := range l {
go (*f)()
}
delete(m.onDisconnectedListeners, serverAddress)
}
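Because func values are not comparable in Go, the listener set is keyed by the address of the callback parameter: every addListener call produces a distinct *func() key, so the same callback can be registered more than once, and notifyOnDisconnectListeners fires and drops the whole set for that server in one shot. The trick in isolation:

package main

import (
    "fmt"
    "sync"
)

type listenerSet struct {
    mu        sync.Mutex
    listeners map[*func()]struct{}
}

func (s *listenerSet) add(f func()) {
    s.mu.Lock()
    defer s.mu.Unlock()
    if s.listeners == nil {
        s.listeners = make(map[*func()]struct{})
    }
    s.listeners[&f] = struct{}{} // &f is unique per call, so duplicates are allowed
}

func (s *listenerSet) notifyAll() {
    s.mu.Lock()
    defer s.mu.Unlock()
    for f := range s.listeners {
        go (*f)()
    }
    s.listeners = nil // one-shot: each registration fires at most once
}

func main() {
    var s listenerSet
    var wg sync.WaitGroup
    wg.Add(2)
    cb := func() { fmt.Println("disconnected"); wg.Done() }
    s.add(cb)
    s.add(cb) // registering the same callback twice still yields two entries
    s.notifyAll()
    wg.Wait()
}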


@ -61,11 +61,11 @@ func TestForeignConn(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
} }
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob) connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob, nil)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice) connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice, nil)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -139,7 +139,7 @@ func TestForeginConnClose(t *testing.T) {
mgr := NewManager(mCtx, addr1, idAlice) mgr := NewManager(mCtx, addr1, idAlice)
mgr.Serve() mgr.Serve()
conn, err := mgr.OpenConn(addr2, "anotherpeer") conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -198,7 +198,7 @@ func TestForeginAutoClose(t *testing.T) {
mgr.Serve() mgr.Serve()
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(addr2, "anotherpeer") conn, err := mgr.OpenConn(addr2, "anotherpeer", nil)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -246,7 +246,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to get relay address: %s", err) t.Errorf("failed to get relay address: %s", err)
} }
conn, err := clientAlice.OpenConn(ra.String(), "bob") conn, err := clientAlice.OpenConn(ra.String(), "bob", nil)
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -264,7 +264,7 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra.String(), "bob") _, err = clientAlice.OpenConn(ra.String(), "bob", nil)
if err != nil { if err != nil {
t.Errorf("failed to open channel: %s", err) t.Errorf("failed to open channel: %s", err)
} }

View File

@ -32,7 +32,7 @@ func init() {
func waitForExitSignal() { func waitForExitSignal() {
osSigs := make(chan os.Signal, 1) osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
_ = <-osSigs <-osSigs
} }
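The "_ =" in front of a plain channel receive was redundant; "<-osSigs" on its own already blocks until a signal arrives and discards the value. The whole pattern in miniature:

package main

import (
    "fmt"
    "os"
    "os/signal"
    "syscall"
)

func main() {
    osSigs := make(chan os.Signal, 1)
    signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)

    fmt.Println("waiting for SIGINT or SIGTERM")
    <-osSigs // blocks until a signal arrives; the value itself is not needed
    fmt.Println("shutting down")
}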
func execute(cmd *cobra.Command, args []string) { func execute(cmd *cobra.Command, args []string) {


@ -16,5 +16,5 @@ func HashID(peerID string) ([]byte, string) {
} }
func HashIDToString(idHash []byte) string { func HashIDToString(idHash []byte) string {
return base64.StdEncoding.EncodeToString(idHash[:]) return base64.StdEncoding.EncodeToString(idHash)
} }


@ -80,7 +80,7 @@ func (r *Server) Close() error {
func (r *Server) accept(conn net.Conn) { func (r *Server) accept(conn net.Conn) {
peer, err := handShake(conn) peer, err := handShake(conn)
if err != nil { if err != nil {
log.Errorf("failed to handshake wiht %s: %s", conn.RemoteAddr(), err) log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
@ -134,7 +134,6 @@ func (r *Server) accept(conn net.Conn) {
if err != nil { if err != nil {
peer.Log.Errorf("failed to write transport message to: %s", dp.String()) peer.Log.Errorf("failed to write transport message to: %s", dp.String())
} }
return
}() }()
case messages.MsgClose: case messages.MsgClose:
peer.Log.Infof("peer disconnected gracefully") peer.Log.Infof("peer disconnected gracefully")