Refactor the worker callbacks to channel

This commit is contained in:
Zoltán Papp 2024-10-31 21:03:10 +01:00
parent ec5095ba6b
commit b33c83c3f8
3 changed files with 103 additions and 53 deletions

View File

@ -134,36 +134,29 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
} }
rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
}
wFns := WorkerICECallbacks{
OnConnReady: conn.iCEConnectionIsReady,
OnStatusChanged: conn.onWorkerICEStateDisconnected,
}
ctrl := isController(config) ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
conn.workerRelay.OnNewOffer(ctx, remoteOfferAnswer)
})
if os.Getenv("NB_FORCE_RELAY") != "true" { if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
conn.workerICE.OnNewOffer(ctx, remoteOfferAnswer)
})
} }
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher) conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen() go conn.handshaker.Listen()
return conn, nil return conn, nil
} }
@ -190,6 +183,7 @@ func (conn *Conn) Open() {
} }
go conn.startHandshakeAndReconnect(conn.ctx) go conn.startHandshakeAndReconnect(conn.ctx)
go conn.listenWorkersEvents()
} }
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
@ -301,7 +295,7 @@ func (conn *Conn) GetKey() string {
} }
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) iCEConnectionIsReady(iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@ -311,7 +305,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready") conn.log.Debugf("ICE connection is ready")
if conn.currentConnPriority > priority { if conn.currentConnPriority > iceConnInfo.ConnPriority {
conn.statusICE.Set(StatusConnected) conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
return return
@ -333,7 +327,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
ep = wgProxy.EndpointAddr() ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy conn.wgProxyICE = wgProxy
} else { } else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteIceCandidateEndpoint)
if err != nil { if err != nil {
log.Errorf("failed to resolveUDPaddr") log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil) conn.handleConfigurationFailure(err, nil)
@ -361,14 +355,13 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
conn.currentConnPriority = priority conn.currentConnPriority = iceConnInfo.ConnPriority
conn.statusICE.Set(StatusConnected) conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
// todo review to make sense to handle connecting and disconnected status also? func (conn *Conn) onWorkerICEStateDisconnected() {
func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@ -376,7 +369,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
return return
} }
conn.log.Tracef("ICE connection state changed to %s", newState) conn.log.Tracef("ICE connection state changed to disconnected")
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil { if err := conn.wgProxyICE.CloseConn(); err != nil {
@ -396,8 +389,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} }
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != stateDisconnected
conn.statusICE.Set(newState) conn.statusICE.Set(stateDisconnected)
conn.guard.SetICEConnDisconnected(changed) conn.guard.SetICEConnDisconnected(changed)
@ -731,6 +724,33 @@ func (conn *Conn) logTraceConnState() {
} }
} }
func (conn *Conn) listenWorkersEvents() {
for {
select {
case e := <-conn.workerRelay.EventChan:
switch e.ConnStatus {
case StatusConnected:
conn.relayConnectionIsReady(e.RelayConnInfo)
case StatusDisconnected:
conn.onWorkerRelayStateDisconnected()
default:
log.Errorf("unexpected relay connection status: %v", e.ConnStatus)
}
case e := <-conn.workerICE.EventChan:
switch e.ConnStatus {
case StatusConnected:
conn.iCEConnectionIsReady(e.ICEConnInfo)
case StatusDisconnected:
conn.onWorkerICEStateDisconnected()
default:
log.Errorf("unexpected ICE connection status: %v", e.ConnStatus)
}
case <-conn.ctx.Done():
return
}
}
}
func isController(config ConnConfig) bool { func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }

View File

@ -19,8 +19,14 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type ICEEvent struct {
ConnStatus ConnStatus
ICEConnInfo ICEConnInfo
}
type ICEConnInfo struct { type ICEConnInfo struct {
RemoteConn net.Conn RemoteConn net.Conn
RemoteAddr net.Addr
RosenpassPubKey []byte RosenpassPubKey []byte
RosenpassAddr string RosenpassAddr string
LocalIceCandidateType string LocalIceCandidateType string
@ -29,14 +35,11 @@ type ICEConnInfo struct {
LocalIceCandidateEndpoint string LocalIceCandidateEndpoint string
Relayed bool Relayed bool
RelayedOnLocal bool RelayedOnLocal bool
} ConnPriority ConnPriority
type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus)
} }
type WorkerICE struct { type WorkerICE struct {
EventChan chan ICEEvent
ctx context.Context ctx context.Context
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
@ -44,7 +47,6 @@ type WorkerICE struct {
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status statusRecorder *Status
hasRelayOnLocally bool hasRelayOnLocally bool
conn WorkerICECallbacks
selectedPriority ConnPriority selectedPriority ConnPriority
@ -59,8 +61,9 @@ type WorkerICE struct {
localPwd string localPwd string
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
w := &WorkerICE{ w := &WorkerICE{
EventChan: make(chan ICEEvent, 2),
ctx: ctx, ctx: ctx,
log: log, log: log,
config: config, config: config,
@ -68,7 +71,6 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal
iFaceDiscover: ifaceDiscover, iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally, hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks,
} }
localUfrag, localPwd, err := icemaker.GenerateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@ -80,7 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal
return w, nil return w, nil
} }
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(_ context.Context, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE") w.log.Debugf("OnNewOffer for ICE")
w.muxAgent.Lock() w.muxAgent.Lock()
@ -133,6 +135,11 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return return
} }
if pair == nil {
w.log.Errorf("remote address is nil, ICE conn already closed")
return
}
if !isRelayCandidate(pair.Local) { if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one // dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort remoteWgPort := iface.DefaultWgPort
@ -154,9 +161,13 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair), Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
ConnPriority: w.selectedPriority,
} }
w.log.Debugf("on ICE conn read to use ready") w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci) select {
case w.EventChan <- ICEEvent{ConnStatus: StatusConnected, ICEConnInfo: ci}:
case <-w.ctx.Done():
}
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@ -216,7 +227,10 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected) select {
case w.EventChan <- ICEEvent{ConnStatus: StatusDisconnected}:
case <-w.ctx.Done():
}
w.muxAgent.Lock() w.muxAgent.Lock()
agentCancel() agentCancel()

View File

@ -18,23 +18,23 @@ var (
wgHandshakeOvertime = 30 * time.Second wgHandshakeOvertime = 30 * time.Second
) )
type RelayEvent struct {
ConnStatus ConnStatus
RelayConnInfo RelayConnInfo
}
type RelayConnInfo struct { type RelayConnInfo struct {
relayedConn net.Conn relayedConn net.Conn
rosenpassPubKey []byte rosenpassPubKey []byte
rosenpassAddr string rosenpassAddr string
} }
type WorkerRelayCallbacks struct {
OnConnReady func(RelayConnInfo)
OnDisconnected func()
}
type WorkerRelay struct { type WorkerRelay struct {
EventChan chan RelayEvent
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
relayManager relayClient.ManagerService relayManager relayClient.ManagerService
callBacks WorkerRelayCallbacks
relayedConn net.Conn relayedConn net.Conn
relayLock sync.Mutex relayLock sync.Mutex
@ -45,18 +45,18 @@ type WorkerRelay struct {
relaySupportedOnRemotePeer atomic.Bool relaySupportedOnRemotePeer atomic.Bool
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
EventChan: make(chan RelayEvent, 2),
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
relayManager: relayManager, relayManager: relayManager,
callBacks: callbacks,
} }
return r return r
} }
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerRelay) OnNewOffer(ctx context.Context, remoteOfferAnswer *OfferAnswer) {
if !w.isRelaySupported(remoteOfferAnswer) { if !w.isRelaySupported(remoteOfferAnswer) {
w.log.Infof("Relay is not supported by remote peer") w.log.Infof("Relay is not supported by remote peer")
w.relaySupportedOnRemotePeer.Store(false) w.relaySupportedOnRemotePeer.Store(false)
@ -87,7 +87,9 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relayedConn = relayedConn w.relayedConn = relayedConn
w.relayLock.Unlock() w.relayLock.Unlock()
err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected) err = w.relayManager.AddCloseListener(srv, func() {
w.onRelayMGDisconnected(ctx)
})
if err != nil { if err != nil {
log.Errorf("failed to add close listener: %s", err) log.Errorf("failed to add close listener: %s", err)
_ = relayedConn.Close() _ = relayedConn.Close()
@ -95,11 +97,17 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
w.log.Debugf("peer conn opened via Relay: %s", srv) w.log.Debugf("peer conn opened via Relay: %s", srv)
go w.callBacks.OnConnReady(RelayConnInfo{ select {
case w.EventChan <- RelayEvent{
ConnStatus: StatusConnected,
RelayConnInfo: RelayConnInfo{
relayedConn: relayedConn, relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr, rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
}) },
}:
case <-ctx.Done():
}
} }
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
@ -187,7 +195,11 @@ func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.Cancel
w.relayLock.Lock() w.relayLock.Lock()
_ = w.relayedConn.Close() _ = w.relayedConn.Close()
w.relayLock.Unlock() w.relayLock.Unlock()
w.callBacks.OnDisconnected()
select {
case w.EventChan <- RelayEvent{ConnStatus: StatusDisconnected}:
case <-ctx.Done():
}
return return
} }
@ -225,12 +237,16 @@ func (w *WorkerRelay) wgState() (time.Time, error) {
return wgState.LastHandshake, nil return wgState.LastHandshake, nil
} }
func (w *WorkerRelay) onRelayMGDisconnected() { func (w *WorkerRelay) onRelayMGDisconnected(ctx context.Context) {
w.ctxLock.Lock() w.ctxLock.Lock()
defer w.ctxLock.Unlock() defer w.ctxLock.Unlock()
if w.ctxCancelWgWatch != nil { if w.ctxCancelWgWatch != nil {
w.ctxCancelWgWatch() w.ctxCancelWgWatch()
} }
go w.callBacks.OnDisconnected()
select {
case w.EventChan <- RelayEvent{ConnStatus: StatusDisconnected}:
case <-ctx.Done():
}
} }