From e26e2c3a75a81c3e13a0d4f3ec08abf4d9a801d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Tue, 18 Jun 2024 17:40:37 +0200 Subject: [PATCH] Add conn status handling and protect agent --- client/internal/peer/conn.go | 105 +++++++++++++++++++++------ client/internal/peer/handshaker.go | 4 +- client/internal/peer/status_test.go | 10 +-- client/internal/peer/worker_ice.go | 59 +++++++++------ client/internal/peer/worker_relay.go | 4 +- 5 files changed, 126 insertions(+), 56 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 1b2538187..268071593 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -80,7 +80,8 @@ type Conn struct { onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string, wgIP string) - status ConnStatus + statusRelay ConnStatus + statusICE ConnStatus workerICE *WorkerICE workerRelay *WorkerRelay @@ -113,11 +114,12 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu signaler: signaler, allowedIPsIP: allowedIPsIP.String(), handshaker: NewHandshaker(ctx, config, signaler), - status: StatusDisconnected, + statusRelay: StatusDisconnected, + statusICE: StatusDisconnected, closeCh: make(chan struct{}), } - conn.workerICE = NewWorkerICE(ctx, conn.log, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, conn.iCEConnectionIsReady, conn.doHandshake) - conn.workerRelay = NewWorkerRelay(ctx, conn.log, relayManager, config, conn.relayConnectionIsReady, conn.doHandshake) + conn.workerICE = NewWorkerICE(ctx, conn.log, config, config.ICEConfig, signaler, iFaceDiscover, statusRecorder, conn.iCEConnectionIsReady, conn.onWorkerICEStateChanged, conn.doHandshake) + conn.workerRelay = NewWorkerRelay(ctx, conn.log, relayManager, config, conn.relayConnectionIsReady, conn.onWorkerRelayStateChanged, conn.doHandshake) return conn, nil } @@ -140,18 +142,6 @@ func (conn *Conn) Open() { conn.log.Warnf("error while updating the state err: %v", err) } - /* - peerState = State{ - PubKey: conn.config.Key, - ConnStatus: StatusConnecting, - ConnStatusUpdate: time.Now(), - Mux: new(sync.RWMutex), - } - err = conn.statusRecorder.UpdatePeerState(peerState) - if err != nil { - log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err) - } - */ relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() if relayIsSupportedLocally { go conn.workerRelay.SetupRelayConnection() @@ -188,15 +178,16 @@ func (conn *Conn) Close() { conn.connID = "" } - if conn.status == StatusConnected && conn.onDisconnected != nil { + if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) } - conn.status = StatusDisconnected + conn.statusRelay = StatusDisconnected + conn.statusICE = StatusDisconnected peerState := State{ PubKey: conn.config.Key, - ConnStatus: conn.status, + ConnStatus: StatusDisconnected, ConnStatusUpdate: time.Now(), Mux: new(sync.RWMutex), } @@ -214,7 +205,7 @@ func (conn *Conn) Close() { // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // doesn't block, discards the message if connection wasn't ready func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { - conn.log.Debugf("OnRemoteAnswer, status %s", conn.status.String()) + conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay) return conn.handshaker.OnRemoteAnswer(answer) } @@ -242,7 +233,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string) } func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { - conn.log.Debugf("OnRemoteOffer, on status %s", conn.status.String()) + conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay) return conn.handshaker.OnRemoteOffer(offer) } @@ -255,13 +246,65 @@ func (conn *Conn) WgConfig() WgConfig { func (conn *Conn) Status() ConnStatus { conn.mu.Lock() defer conn.mu.Unlock() - return conn.status + return conn.evalStatus() } func (conn *Conn) GetKey() string { return conn.config.Key } +func (conn *Conn) onWorkerICEStateChanged(newState ConnStatus) { + conn.mu.Lock() + defer conn.mu.Unlock() + defer func() { + conn.statusICE = newState + }() + + if conn.statusRelay == StatusConnected { + return + } + + if conn.evalStatus() == newState { + return + } + + if newState > conn.statusICE { + peerState := State{ + PubKey: conn.config.Key, + ConnStatus: newState, + ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), + } + _ = conn.statusRecorder.UpdatePeerState(peerState) + } +} + +func (conn *Conn) onWorkerRelayStateChanged(newState ConnStatus) { + conn.mu.Lock() + defer conn.mu.Unlock() + defer func() { + conn.statusRelay = newState + }() + + if conn.statusICE == StatusConnected { + return + } + + if conn.evalStatus() == newState { + return + } + + if newState > conn.statusRelay { + peerState := State{ + PubKey: conn.config.Key, + ConnStatus: newState, + ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), + } + _ = conn.statusRecorder.UpdatePeerState(peerState) + } +} + func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { conn.mu.Lock() defer conn.mu.Unlock() @@ -270,6 +313,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { return } + conn.statusRelay = stateConnected + if conn.currentConnType > connPriorityRelay { return } @@ -326,6 +371,8 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon return } + conn.statusICE = stateConnected + if conn.currentConnType > priority { return } @@ -390,8 +437,6 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon } func (conn *Conn) updateStatus(peerState State, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) { - conn.status = StatusConnected - peerState.PubKey = conn.config.Key peerState.ConnStatus = StatusConnected peerState.ConnStatusUpdate = time.Now() @@ -434,6 +479,18 @@ func (conn *Conn) doHandshake() (*OfferAnswer, error) { }) } +func (conn *Conn) evalStatus() ConnStatus { + if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected { + return StatusConnected + } + + if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting { + return StatusConnecting + } + + return StatusDisconnected +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index e10705f63..2bc42686f 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -116,7 +116,7 @@ func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool { case h.remoteOffersCh <- offer: return true default: - log.Debugf("OnRemoteOffer skipping message from peer %s on status %s because is not ready", h.config.Key, conn.status.String()) + log.Debugf("OnRemoteOffer skipping message from peer %s because is not ready", h.config.Key) // connection might not be ready yet to receive so we ignore the message return false } @@ -130,7 +130,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool { return true default: // connection might not be ready yet to receive so we ignore the message - log.Debugf("OnRemoteAnswer skipping message from peer %s on status %s because is not ready", h.config.Key, conn.status.String()) + log.Debugf("OnRemoteAnswer skipping message from peer %s because is not ready", h.config.Key) return false } } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index a4a6e6081..1d283433b 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -2,8 +2,8 @@ package peer import ( "errors" - "testing" "sync" + "testing" "github.com/stretchr/testify/assert" ) @@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 70668b493..33362e86f 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -8,6 +8,7 @@ import ( "net" "net/netip" "runtime" + "sync" "sync/atomic" "time" @@ -59,36 +60,39 @@ type ICEConnInfo struct { } type WorkerICE struct { - ctx context.Context - log *log.Entry - config ConnConfig - configICE ICEConfig - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - statusRecorder *Status - onICEConnReady OnICEConnReadyCallback - doHandshakeFn DoHandshake + ctx context.Context + log *log.Entry + config ConnConfig + configICE ICEConfig + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + statusRecorder *Status + onICEConnReady OnICEConnReadyCallback + onStatusChanged func(ConnStatus) + doHandshakeFn DoHandshake - connPriority ConnPriority + selectedPriority ConnPriority - agent *ice.Agent + agent *ice.Agent + muxAgent sync.RWMutex StunTurn []*stun.URI sentExtraSrflx bool } -func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, configICE ICEConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, onICEConnReady OnICEConnReadyCallback, doHandshakeFn DoHandshake) *WorkerICE { +func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, configICE ICEConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, onICEConnReady OnICEConnReadyCallback, onStatusChanged func(ConnStatus), doHandshakeFn DoHandshake) *WorkerICE { cice := &WorkerICE{ - ctx: ctx, - log: log, - config: config, - configICE: configICE, - signaler: signaler, - iFaceDiscover: ifaceDiscover, - statusRecorder: statusRecorder, - onICEConnReady: onICEConnReady, - doHandshakeFn: doHandshakeFn, + ctx: ctx, + log: log, + config: config, + configICE: configICE, + signaler: signaler, + iFaceDiscover: ifaceDiscover, + statusRecorder: statusRecorder, + onICEConnReady: onICEConnReady, + onStatusChanged: onStatusChanged, + doHandshakeFn: doHandshakeFn, } return cice } @@ -103,6 +107,8 @@ func (w *WorkerICE) SetupICEConnection(hasRelayOnLocally bool) { return } + w.onStatusChanged(StatusConnecting) + remoteOfferAnswer, err := w.doHandshakeFn() if err != nil { if errors.Is(err, ErrSignalIsNotReady) { @@ -113,10 +119,10 @@ func (w *WorkerICE) SetupICEConnection(hasRelayOnLocally bool) { var preferredCandidateTypes []ice.CandidateType if hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { - w.connPriority = connPriorityICEP2P + w.selectedPriority = connPriorityICEP2P preferredCandidateTypes = candidateTypesP2P() } else { - w.connPriority = connPriorityICETurn + w.selectedPriority = connPriorityICETurn preferredCandidateTypes = candidateTypes() } @@ -126,7 +132,9 @@ func (w *WorkerICE) SetupICEConnection(hasRelayOnLocally bool) { ctxCancel() continue } + w.muxAgent.Lock() w.agent = agent + w.muxAgent.Unlock() err = w.agent.GatherCandidates() if err != nil { @@ -172,16 +180,19 @@ func (w *WorkerICE) SetupICEConnection(hasRelayOnLocally bool) { Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), } - go w.onICEConnReady(w.connPriority, ci) + go w.onICEConnReady(w.selectedPriority, ci) <-ctx.Done() ctxCancel() _ = w.agent.Close() + w.onStatusChanged(StatusDisconnected) } } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { + w.muxAgent.RLocker() + defer w.muxAgent.RUnlock() w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String()) if w.agent == nil { return diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 762c69b3e..078fe84da 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -26,16 +26,18 @@ type WorkerRelay struct { relayManager *relayClient.Manager config ConnConfig onRelayConnReadyFN OnRelayReadyCallback + onStatusChanged func(ConnStatus) doHandshakeFn DoHandshake } -func NewWorkerRelay(ctx context.Context, log *log.Entry, relayManager *relayClient.Manager, config ConnConfig, onRelayConnReadyFN OnRelayReadyCallback, doHandshakeFn DoHandshake) *WorkerRelay { +func NewWorkerRelay(ctx context.Context, log *log.Entry, relayManager *relayClient.Manager, config ConnConfig, onRelayConnReadyFN OnRelayReadyCallback, onStatusChanged func(ConnStatus), doHandshakeFn DoHandshake) *WorkerRelay { return &WorkerRelay{ ctx: ctx, log: log, relayManager: relayManager, config: config, onRelayConnReadyFN: onRelayConnReadyFN, + onStatusChanged: onStatusChanged, doHandshakeFn: doHandshakeFn, } }