mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
[client] Fix race condition while read/write conn status in peer conn (#2607)
This commit is contained in:
parent
5bc601111d
commit
1104c9c048
@ -89,8 +89,8 @@ type Conn struct {
|
|||||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||||
onDisconnected func(remotePeer string, wgIP string)
|
onDisconnected func(remotePeer string, wgIP string)
|
||||||
|
|
||||||
statusRelay ConnStatus
|
statusRelay *AtomicConnStatus
|
||||||
statusICE ConnStatus
|
statusICE *AtomicConnStatus
|
||||||
currentConnPriority ConnPriority
|
currentConnPriority ConnPriority
|
||||||
opened bool // this flag is used to prevent close in case of not opened connection
|
opened bool // this flag is used to prevent close in case of not opened connection
|
||||||
|
|
||||||
@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
allowedIPsIP: allowedIPsIP.String(),
|
allowedIPsIP: allowedIPsIP.String(),
|
||||||
statusRelay: StatusDisconnected,
|
statusRelay: NewAtomicConnStatus(),
|
||||||
statusICE: StatusDisconnected,
|
statusICE: NewAtomicConnStatus(),
|
||||||
iCEDisconnected: make(chan bool, 1),
|
iCEDisconnected: make(chan bool, 1),
|
||||||
relayDisconnected: make(chan bool, 1),
|
relayDisconnected: make(chan bool, 1),
|
||||||
}
|
}
|
||||||
@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
|
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
|
||||||
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if conn.statusICE == StatusDisconnected {
|
if conn.statusICE.Get() == StatusDisconnected {
|
||||||
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
|
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
|||||||
|
|
||||||
conn.log.Debugf("ICE connection is ready")
|
conn.log.Debugf("ICE connection is ready")
|
||||||
|
|
||||||
conn.statusICE = StatusConnected
|
conn.statusICE.Set(StatusConnected)
|
||||||
|
|
||||||
defer conn.updateIceState(iceConnInfo)
|
defer conn.updateIceState(iceConnInfo)
|
||||||
|
|
||||||
@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
|
|||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = connPriorityRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE != newState && newState != StatusConnecting
|
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
|
||||||
conn.statusICE = newState
|
conn.statusICE.Set(newState)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case conn.iCEDisconnected <- changed:
|
case conn.iCEDisconnected <- changed:
|
||||||
@ -522,7 +522,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("Relay connection is ready to use")
|
conn.log.Debugf("Relay connection is ready to use")
|
||||||
conn.statusRelay = StatusConnected
|
conn.statusRelay.Set(StatusConnected)
|
||||||
|
|
||||||
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||||
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
|
||||||
@ -538,7 +538,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
|
|
||||||
if conn.currentConnPriority > connPriorityRelay {
|
if conn.currentConnPriority > connPriorityRelay {
|
||||||
if conn.statusICE == StatusConnected {
|
if conn.statusICE.Get() == StatusConnected {
|
||||||
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -594,8 +594,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
|
|||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay != StatusDisconnected
|
changed := conn.statusRelay.Get() != StatusDisconnected
|
||||||
conn.statusRelay = StatusDisconnected
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case conn.relayDisconnected <- changed:
|
case conn.relayDisconnected <- changed:
|
||||||
@ -661,8 +661,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) setStatusToDisconnected() {
|
func (conn *Conn) setStatusToDisconnected() {
|
||||||
conn.statusRelay = StatusDisconnected
|
conn.statusRelay.Set(StatusDisconnected)
|
||||||
conn.statusICE = StatusDisconnected
|
conn.statusICE.Set(StatusDisconnected)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -706,7 +706,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) isRelayed() bool {
|
func (conn *Conn) isRelayed() bool {
|
||||||
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
|
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -718,11 +718,11 @@ func (conn *Conn) isRelayed() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) evalStatus() ConnStatus {
|
func (conn *Conn) evalStatus() ConnStatus {
|
||||||
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
|
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
|
||||||
return StatusConnected
|
return StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
|
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
|
||||||
return StatusConnecting
|
return StatusConnecting
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -733,12 +733,12 @@ func (conn *Conn) isConnected() bool {
|
|||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
|
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
if conn.statusRelay != StatusConnected {
|
if conn.statusRelay.Get() != StatusConnected {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import log "github.com/sirupsen/logrus"
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// StatusConnected indicate the peer is in connected state
|
// StatusConnected indicate the peer is in connected state
|
||||||
@ -12,7 +16,34 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConnStatus describe the status of a peer's connection
|
// ConnStatus describe the status of a peer's connection
|
||||||
type ConnStatus int
|
type ConnStatus int32
|
||||||
|
|
||||||
|
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
||||||
|
type AtomicConnStatus struct {
|
||||||
|
status atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
||||||
|
func NewAtomicConnStatus() *AtomicConnStatus {
|
||||||
|
acs := &AtomicConnStatus{}
|
||||||
|
acs.Set(StatusDisconnected)
|
||||||
|
return acs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the current connection status
|
||||||
|
func (acs *AtomicConnStatus) Get() ConnStatus {
|
||||||
|
return ConnStatus(acs.status.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set updates the connection status
|
||||||
|
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
||||||
|
acs.status.Store(int32(status))
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the current status
|
||||||
|
func (acs *AtomicConnStatus) String() string {
|
||||||
|
return acs.Get().String()
|
||||||
|
}
|
||||||
|
|
||||||
func (s ConnStatus) String() string {
|
func (s ConnStatus) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
|
@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) {
|
|||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
t.Run(table.name, func(t *testing.T) {
|
t.Run(table.name, func(t *testing.T) {
|
||||||
conn.statusICE = table.statusIce
|
si := NewAtomicConnStatus()
|
||||||
conn.statusRelay = table.statusRelay
|
si.Set(table.statusIce)
|
||||||
|
conn.statusICE = si
|
||||||
|
|
||||||
|
sr := NewAtomicConnStatus()
|
||||||
|
sr.Set(table.statusRelay)
|
||||||
|
conn.statusRelay = sr
|
||||||
|
|
||||||
got := conn.Status()
|
got := conn.Status()
|
||||||
assert.Equal(t, got, table.want, "they should be equal")
|
assert.Equal(t, got, table.want, "they should be equal")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user