diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 56b772759..69653dca9 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -134,36 +134,29 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu statusICE: NewAtomicConnStatus(), } - rFns := WorkerRelayCallbacks{ - OnConnReady: conn.relayConnectionIsReady, - OnDisconnected: conn.onWorkerRelayStateDisconnected, - } - - wFns := WorkerICECallbacks{ - OnConnReady: conn.iCEConnectionIsReady, - OnStatusChanged: conn.onWorkerICEStateDisconnected, - } - ctrl := isController(config) - conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) + conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) + conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) if err != nil { return nil, err } conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay) - conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) + conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { + conn.workerRelay.OnNewOffer(ctx, remoteOfferAnswer) + }) if os.Getenv("NB_FORCE_RELAY") != "true" { - conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) + conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { + conn.workerICE.OnNewOffer(ctx, remoteOfferAnswer) + }) } conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher) go conn.handshaker.Listen() - return conn, nil } @@ -190,6 +183,7 @@ func (conn *Conn) Open() { } go conn.startHandshakeAndReconnect(conn.ctx) + go conn.listenWorkersEvents() } func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { @@ -301,7 +295,7 @@ func (conn *Conn) GetKey() string { } // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected -func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { +func (conn *Conn) iCEConnectionIsReady(iceConnInfo ICEConnInfo) { conn.mu.Lock() defer conn.mu.Unlock() @@ -311,7 +305,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - if conn.currentConnPriority > priority { + if conn.currentConnPriority > iceConnInfo.ConnPriority { conn.statusICE.Set(StatusConnected) conn.updateIceState(iceConnInfo) return @@ -333,7 +327,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon ep = wgProxy.EndpointAddr() conn.wgProxyICE = wgProxy } else { - directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteIceCandidateEndpoint) if err != nil { log.Errorf("failed to resolveUDPaddr") conn.handleConfigurationFailure(err, nil) @@ -361,14 +355,13 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon return } wgConfigWorkaround() - conn.currentConnPriority = priority + conn.currentConnPriority = iceConnInfo.ConnPriority conn.statusICE.Set(StatusConnected) conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } -// todo review to make sense to handle connecting and disconnected status also? -func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { +func (conn *Conn) onWorkerICEStateDisconnected() { conn.mu.Lock() defer conn.mu.Unlock() @@ -376,7 +369,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { return } - conn.log.Tracef("ICE connection state changed to %s", newState) + conn.log.Tracef("ICE connection state changed to disconnected") if conn.wgProxyICE != nil { if err := conn.wgProxyICE.CloseConn(); err != nil { @@ -396,8 +389,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.currentConnPriority = connPriorityRelay } - changed := conn.statusICE.Get() != newState && newState != StatusConnecting - conn.statusICE.Set(newState) + changed := conn.statusICE.Get() != stateDisconnected + conn.statusICE.Set(stateDisconnected) conn.guard.SetICEConnDisconnected(changed) @@ -731,6 +724,33 @@ func (conn *Conn) logTraceConnState() { } } +func (conn *Conn) listenWorkersEvents() { + for { + select { + case e := <-conn.workerRelay.EventChan: + switch e.ConnStatus { + case StatusConnected: + conn.relayConnectionIsReady(e.RelayConnInfo) + case StatusDisconnected: + conn.onWorkerRelayStateDisconnected() + default: + log.Errorf("unexpected relay connection status: %v", e.ConnStatus) + } + case e := <-conn.workerICE.EventChan: + switch e.ConnStatus { + case StatusConnected: + conn.iCEConnectionIsReady(e.ICEConnInfo) + case StatusDisconnected: + conn.onWorkerICEStateDisconnected() + default: + log.Errorf("unexpected ICE connection status: %v", e.ConnStatus) + } + case <-conn.ctx.Done(): + return + } + } +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 55894218d..3cd9e1bce 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -19,8 +19,14 @@ import ( "github.com/netbirdio/netbird/route" ) +type ICEEvent struct { + ConnStatus ConnStatus + ICEConnInfo ICEConnInfo +} + type ICEConnInfo struct { RemoteConn net.Conn + RemoteAddr net.Addr RosenpassPubKey []byte RosenpassAddr string LocalIceCandidateType string @@ -29,14 +35,11 @@ type ICEConnInfo struct { LocalIceCandidateEndpoint string Relayed bool RelayedOnLocal bool -} - -type WorkerICECallbacks struct { - OnConnReady func(ConnPriority, ICEConnInfo) - OnStatusChanged func(ConnStatus) + ConnPriority ConnPriority } type WorkerICE struct { + EventChan chan ICEEvent ctx context.Context log *log.Entry config ConnConfig @@ -44,7 +47,6 @@ type WorkerICE struct { iFaceDiscover stdnet.ExternalIFaceDiscover statusRecorder *Status hasRelayOnLocally bool - conn WorkerICECallbacks selectedPriority ConnPriority @@ -59,8 +61,9 @@ type WorkerICE struct { localPwd string } -func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { +func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { w := &WorkerICE{ + EventChan: make(chan ICEEvent, 2), ctx: ctx, log: log, config: config, @@ -68,7 +71,6 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal iFaceDiscover: ifaceDiscover, statusRecorder: statusRecorder, hasRelayOnLocally: hasRelayOnLocally, - conn: callBacks, } localUfrag, localPwd, err := icemaker.GenerateICECredentials() @@ -80,7 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal return w, nil } -func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { +func (w *WorkerICE) OnNewOffer(_ context.Context, remoteOfferAnswer *OfferAnswer) { w.log.Debugf("OnNewOffer for ICE") w.muxAgent.Lock() @@ -133,6 +135,11 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { return } + if pair == nil { + w.log.Errorf("remote address is nil, ICE conn already closed") + return + } + if !isRelayCandidate(pair.Local) { // dynamically set remote WireGuard port if other side specified a different one from the default one remoteWgPort := iface.DefaultWgPort @@ -154,9 +161,13 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), + ConnPriority: w.selectedPriority, } w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(w.selectedPriority, ci) + select { + case w.EventChan <- ICEEvent{ConnStatus: StatusConnected, ICEConnInfo: ci}: + case <-w.ctx.Done(): + } } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -216,7 +227,10 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { - w.conn.OnStatusChanged(StatusDisconnected) + select { + case w.EventChan <- ICEEvent{ConnStatus: StatusDisconnected}: + case <-w.ctx.Done(): + } w.muxAgent.Lock() agentCancel() diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c22dcdeda..56dcaa579 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -18,23 +18,23 @@ var ( wgHandshakeOvertime = 30 * time.Second ) +type RelayEvent struct { + ConnStatus ConnStatus + RelayConnInfo RelayConnInfo +} + type RelayConnInfo struct { relayedConn net.Conn rosenpassPubKey []byte rosenpassAddr string } -type WorkerRelayCallbacks struct { - OnConnReady func(RelayConnInfo) - OnDisconnected func() -} - type WorkerRelay struct { + EventChan chan RelayEvent log *log.Entry isController bool config ConnConfig relayManager relayClient.ManagerService - callBacks WorkerRelayCallbacks relayedConn net.Conn relayLock sync.Mutex @@ -45,18 +45,18 @@ type WorkerRelay struct { relaySupportedOnRemotePeer atomic.Bool } -func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { +func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService) *WorkerRelay { r := &WorkerRelay{ + EventChan: make(chan RelayEvent, 2), log: log, isController: ctrl, config: config, relayManager: relayManager, - callBacks: callbacks, } return r } -func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { +func (w *WorkerRelay) OnNewOffer(ctx context.Context, remoteOfferAnswer *OfferAnswer) { if !w.isRelaySupported(remoteOfferAnswer) { w.log.Infof("Relay is not supported by remote peer") w.relaySupportedOnRemotePeer.Store(false) @@ -87,7 +87,9 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.relayedConn = relayedConn w.relayLock.Unlock() - err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected) + err = w.relayManager.AddCloseListener(srv, func() { + w.onRelayMGDisconnected(ctx) + }) if err != nil { log.Errorf("failed to add close listener: %s", err) _ = relayedConn.Close() @@ -95,11 +97,17 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } w.log.Debugf("peer conn opened via Relay: %s", srv) - go w.callBacks.OnConnReady(RelayConnInfo{ - relayedConn: relayedConn, - rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, - rosenpassAddr: remoteOfferAnswer.RosenpassAddr, - }) + select { + case w.EventChan <- RelayEvent{ + ConnStatus: StatusConnected, + RelayConnInfo: RelayConnInfo{ + relayedConn: relayedConn, + rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, + rosenpassAddr: remoteOfferAnswer.RosenpassAddr, + }, + }: + case <-ctx.Done(): + } } func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { @@ -187,7 +195,11 @@ func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.Cancel w.relayLock.Lock() _ = w.relayedConn.Close() w.relayLock.Unlock() - w.callBacks.OnDisconnected() + + select { + case w.EventChan <- RelayEvent{ConnStatus: StatusDisconnected}: + case <-ctx.Done(): + } return } @@ -225,12 +237,16 @@ func (w *WorkerRelay) wgState() (time.Time, error) { return wgState.LastHandshake, nil } -func (w *WorkerRelay) onRelayMGDisconnected() { +func (w *WorkerRelay) onRelayMGDisconnected(ctx context.Context) { w.ctxLock.Lock() defer w.ctxLock.Unlock() if w.ctxCancelWgWatch != nil { w.ctxCancelWgWatch() } - go w.callBacks.OnDisconnected() + + select { + case w.EventChan <- RelayEvent{ConnStatus: StatusDisconnected}: + case <-ctx.Done(): + } }