[client] Fix controller re-connection (#2758)

Rethink the peer reconnection implementation
Zoltan Papp 2024-10-24 11:43:14 +02:00 committed by GitHub
parent 869537c951
commit 4e918e55ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 813 additions and 523 deletions

View File

@ -79,9 +79,6 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@ -98,7 +95,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin - name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
@ -106,7 +103,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@ -143,7 +143,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
conn, ok := b.endpoints[ep.DstIP()] conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock() b.endpointsMu.Unlock()
if !ok { if !ok {
log.Infof("failed to find endpoint for %s", ep.DstIP())
return b.StdNetBind.Send(bufs, ep) return b.StdNetBind.Send(bufs, ep)
} }

View File

@ -5,9 +5,9 @@ import (
"net" "net"
) )
const ( var (
portRangeStart = 3128 portRangeStart = 3128
portRangeEnd = 3228 portRangeEnd = portRangeStart + 100
) )
type portLookup struct { type portLookup struct {

View File

@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) {
func Test_portLookup_on_allocated(t *testing.T) { func Test_portLookup_on_allocated(t *testing.T) {
pl := portLookup{} pl := portLookup{}
portRangeStart = 4128
portRangeEnd = portRangeStart + 100
allocatedPort, err := allocatePort(portRangeStart) allocatedPort, err := allocatePort(portRangeStart)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -22,9 +22,11 @@ func NewKernelFactory(wgPort int) *KernelFactory {
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f return f
} }
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
f.ebpfProxy = ebpfProxy f.ebpfProxy = ebpfProxy
return f return f
} }

View File

@ -1,6 +1,8 @@
package wgproxy package wgproxy
import ( import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
@ -10,6 +12,7 @@ type KernelFactory struct {
} }
func NewKernelFactory(wgPort int) *KernelFactory { func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{ f := &KernelFactory{
wgPort: wgPort, wgPort: wgPort,
} }

View File

@ -1,6 +1,8 @@
package wgproxy package wgproxy
import ( import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
) )
@ -10,6 +12,7 @@ type USPFactory struct {
} }
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{ f := &USPFactory{
bind: iceBind, bind: iceBind,
} }

View File

@ -2,14 +2,16 @@ package udp
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@ -121,7 +123,7 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
return errors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
// proxyToRemote proxies from Wireguard to the RemoteKey // proxyToRemote proxies from Wireguard to the RemoteKey
@ -160,18 +162,16 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err) if !errors.Is(err, io.EOF) {
log.Warnf("error in proxy to local loop: %s", err)
}
} }
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConnRead(ctx, buf)
if err != nil { if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }
@ -193,3 +193,15 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
} }
} }
} }
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
n, err = p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
return
}
return
}

View File

@ -30,6 +30,8 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@ -168,6 +170,7 @@ type Engine struct {
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -263,6 +266,10 @@ func (e *Engine) Stop() error {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
if e.srWatcher != nil {
e.srWatcher.Close()
}
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
@ -389,6 +396,18 @@ func (e *Engine) Start() error {
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveProbeEvents() e.receiveProbeEvents()
@ -971,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
RosenpassPubKey: e.getRosenpassPubKey(), RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(), RosenpassAddr: e.getRosenpassAddr(),
ICEConfig: peer.ICEConfig{ ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
@ -981,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -29,6 +29,8 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
type testCase struct { type testCase struct {
name string name string

View File

@ -10,7 +10,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -18,6 +17,8 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@ -32,8 +33,6 @@ const (
connPriorityRelay ConnPriority = 1 connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1 connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2 connPriorityICEP2P ConnPriority = 2
reconnectMaxElapsedTime = 30 * time.Minute
) )
type WgConfig struct { type WgConfig struct {
@ -63,7 +62,7 @@ type ConnConfig struct {
RosenpassAddr string RosenpassAddr string
// ICEConfig ICE protocol configuration // ICEConfig ICE protocol configuration
ICEConfig ICEConfig ICEConfig icemaker.Config
} }
type WorkerCallbacks struct { type WorkerCallbacks struct {
@ -106,16 +105,12 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
// for reconnection operations guard *guard.Guard
iCEDisconnected chan bool
relayDisconnected chan bool
connMonitor *ConnMonitor
reconnectCh <-chan struct{}
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil { if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err) log.Errorf("failed to parse allowedIPS: %v", err)
@ -126,29 +121,19 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
log: connLog, log: connLog,
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
config: config, config: config,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
signaler: signaler, signaler: signaler,
relayManager: relayManager, relayManager: relayManager,
allowedIP: allowedIP, allowedIP: allowedIP,
allowedNet: allowedNet.String(), allowedNet: allowedNet.String(),
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
} }
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
signaler,
iFaceDiscover,
config,
conn.relayDisconnected,
conn.iCEDisconnected,
)
rFns := WorkerRelayCallbacks{ rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady, OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected, OnDisconnected: conn.onWorkerRelayStateDisconnected,
@ -159,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
OnStatusChanged: conn.onWorkerICEStateDisconnected, OnStatusChanged: conn.onWorkerICEStateDisconnected,
} }
conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns) ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
@ -174,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
} }
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen() go conn.handshaker.Listen()
return conn, nil return conn, nil
@ -184,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open() {
conn.log.Debugf("open connection to peer") conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.opened = true conn.opened = true
@ -200,24 +189,19 @@ func (conn *Conn) Open() {
conn.log.Warnf("error while updating the state err: %v", err) conn.log.Warnf("error while updating the state err: %v", err)
} }
go conn.startHandshakeAndReconnect() go conn.startHandshakeAndReconnect(conn.ctx)
} }
func (conn *Conn) startHandshakeAndReconnect() { func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
conn.waitInitialRandomSleepTime() conn.waitInitialRandomSleepTime(ctx)
err := conn.handshaker.sendOffer() err := conn.handshaker.sendOffer()
if err != nil { if err != nil {
conn.log.Errorf("failed to send initial offer: %v", err) conn.log.Errorf("failed to send initial offer: %v", err)
} }
go conn.connMonitor.Start(conn.ctx) go conn.guard.Start(ctx)
go conn.listenGuardEvent(ctx)
if conn.workerRelay.IsController() {
conn.reconnectLoopWithRetry()
} else {
conn.reconnectLoopForOnDisconnectedEvent()
}
} }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
@ -316,104 +300,6 @@ func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) reconnectLoopWithRetry() {
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
select {
case <-conn.ctx.Done():
return
case <-time.After(3 * time.Second):
}
ticker := conn.prepareExponentTicker()
defer ticker.Stop()
time.Sleep(1 * time.Second)
for {
select {
case t := <-ticker.C:
if t.IsZero() {
// in case if the ticker has been canceled by context then avoid the temporary loop
return
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
// checks if there is peer connection is established via relay or ice
if conn.isConnected() {
continue
}
err := conn.handshaker.sendOffer()
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
case <-conn.reconnectCh:
ticker.Stop()
ticker = conn.prepareExponentTicker()
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: conn.config.Timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, conn.ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}
// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
for {
select {
case changed := <-conn.relayDisconnected:
if !changed {
continue
}
conn.log.Debugf("Relay state changed, try to send new offer")
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, try to send new offer")
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
}
err := conn.handshaker.SendOffer()
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
}
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
@ -513,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState) conn.statusICE.Set(newState)
conn.notifyReconnectLoopICEDisconnected(changed) conn.guard.SetICEConnDisconnected(changed)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -604,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.Set(StatusDisconnected)
conn.notifyReconnectLoopRelayDisconnected(changed) conn.guard.SetRelayedConnDisconnected(changed)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -617,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
} }
} }
func (conn *Conn) listenGuardEvent(ctx context.Context) {
for {
select {
case <-conn.guard.Reconnect:
conn.log.Debugf("send offer to peer")
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
return conn.config.WgConfig.WgInterface.UpdatePeer( return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey, conn.config.WgConfig.RemoteKey,
@ -693,7 +593,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
} }
func (conn *Conn) waitInitialRandomSleepTime() { func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
minWait := 100 minWait := 100
maxWait := 800 maxWait := 800
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
@ -702,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
defer timeout.Stop() defer timeout.Stop()
select { select {
case <-conn.ctx.Done(): case <-ctx.Done():
case <-timeout.C: case <-timeout.C:
} }
} }
@ -731,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusDisconnected return StatusDisconnected
} }
func (conn *Conn) isConnected() bool { func (conn *Conn) isConnectedOnAllWay() (connected bool) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { defer func() {
if !connected {
conn.logTraceConnState()
}
}()
if conn.statusICE.Get() == StatusDisconnected {
return false return false
} }
@ -805,20 +711,6 @@ func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
} }
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err) conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil { if wgProxy != nil {
@ -831,6 +723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
} }
} }
func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
}
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }
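The controller role referenced in the PR title is decided purely by comparing the two WireGuard public keys (isController above), so exactly one peer of every pair runs the retrying guard loop while the other only reacts to disconnect events. A minimal sketch of that rule, with hypothetical keys, just to show the two sides always land on opposite roles:

package main

import "fmt"

// isController mirrors the rule used in conn.go: the peer whose local key
// sorts lexicographically higher acts as the controller.
func isController(localKey, remoteKey string) bool {
    return localKey > remoteKey
}

func main() {
    // Hypothetical keys; in NetBird these are the peers' WireGuard public keys.
    keyA, keyB := "B9f...", "A3c..."
    fmt.Println(isController(keyA, keyB)) // true:  peer A runs the guard's retry loop
    fmt.Println(isController(keyB, keyA)) // false: peer B only listens for disconnect events
}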

View File

@ -1,212 +0,0 @@
package peer
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
signalerMonitorPeriod = 5 * time.Second
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ConnMonitor struct {
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
config ConnConfig
relayDisconnected chan bool
iCEDisconnected chan bool
reconnectCh chan struct{}
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
reconnectCh := make(chan struct{}, 1)
cm := &ConnMonitor{
signaler: signaler,
iFaceDiscover: iFaceDiscover,
config: config,
relayDisconnected: relayDisconnected,
iCEDisconnected: iCEDisconnected,
reconnectCh: reconnectCh,
}
return cm, reconnectCh
}
func (cm *ConnMonitor) Start(ctx context.Context) {
signalerReady := make(chan struct{}, 1)
go cm.monitorSignalerReady(ctx, signalerReady)
localCandidatesChanged := make(chan struct{}, 1)
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
for {
select {
case changed := <-cm.relayDisconnected:
if !changed {
continue
}
log.Debugf("Relay state changed, triggering reconnect")
cm.triggerReconnect()
case changed := <-cm.iCEDisconnected:
if !changed {
continue
}
log.Debugf("ICE state changed, triggering reconnect")
cm.triggerReconnect()
case <-signalerReady:
log.Debugf("Signaler became ready, triggering reconnect")
cm.triggerReconnect()
case <-localCandidatesChanged:
log.Debugf("Local candidates changed, triggering reconnect")
cm.triggerReconnect()
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
if cm.signaler == nil {
return
}
ticker := time.NewTicker(signalerMonitorPeriod)
defer ticker.Stop()
lastReady := true
for {
select {
case <-ticker.C:
currentReady := cm.signaler.Ready()
if !lastReady && currentReady {
select {
case signalerReady <- struct{}{}:
default:
}
}
lastReady = currentReady
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
ufrag, pwd, err := generateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
log.Warnf("Failed to handle candidate tick: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
log.Debugf("Gathering ICE candidates")
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("wait for gathering: %w", ctx.Err())
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
if changed := cm.updateCandidates(candidates); changed {
select {
case localCandidatesChanged <- struct{}{}:
default:
}
}
return nil
}
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func (cm *ConnMonitor) triggerReconnect() {
select {
case cm.reconnectCh <- struct{}{}:
default:
}
}

View File

@ -10,6 +10,8 @@ import (
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -19,7 +21,7 @@ var connConf = ConnConfig{
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
Timeout: time.Second, Timeout: time.Second,
LocalWgPort: 51820, LocalWgPort: 51820,
ICEConfig: ICEConfig{ ICEConfig: ice.Config{
InterfaceBlackList: nil, InterfaceBlackList: nil,
}, },
} }
@ -43,7 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }
@ -54,7 +57,8 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }
@ -87,7 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }
@ -119,7 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }

View File

@ -0,0 +1,194 @@
package guard
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
const (
reconnectMaxElapsedTime = 30 * time.Minute
)
type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic.
// It triggers sending a new offer to the peer when the connection has issues.
// It watches these events:
// - Relay client reconnected to home server
// - Signal server connection state changed
// - ICE connection disconnected
// - Relayed connection disconnected
// - ICE candidate changes
type Guard struct {
Reconnect chan struct{}
log *log.Entry
isController bool
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan bool
iCEConnDisconnected chan bool
}
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
Reconnect: make(chan struct{}, 1),
log: log,
isController: isController,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
srWatcher: srWatcher,
relayedConnDisconnected: make(chan bool, 1),
iCEConnDisconnected: make(chan bool, 1),
}
}
func (g *Guard) Start(ctx context.Context) {
if g.isController {
g.reconnectLoopWithRetry(ctx)
} else {
g.listenForDisconnectEvents(ctx)
}
}
func (g *Guard) SetRelayedConnDisconnected(changed bool) {
select {
case g.relayedConnDisconnected <- changed:
default:
}
}
func (g *Guard) SetICEConnDisconnected(changed bool) {
select {
case g.iCEConnDisconnected <- changed:
default:
}
}
// reconnectLoopWithRetry periodically checks the connection status for up to 30 minutes.
// It keeps sending offers while the P2P connection is not established, or while the Relay connection is not established even though Relay is supported.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx)
defer ticker.Stop()
tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after the backoff times out, ticker.C is closed; swap in a dummy channel to avoid busy-looping on the closed channel
tickerChannel = make(<-chan time.Time)
continue
}
if !g.isConnectedOnAllWay() {
g.triggerOfferSending()
}
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
// listenForDisconnectEvents is used when the peer is not a controller and should reconnect to the peer
// when the connection is lost. Unlike the controller loop, it sends a single new offer per disconnect event instead of retrying with backoff.
// It tracks the ICE and Relay connection status separately: the fact that a lower-priority connection has been
// re-established does not mean we switch to it; the higher-priority connection is always preferred.
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
g.log.Infof("start listen for reconnect events...")
for {
select {
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
g.triggerOfferSending()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick that fires right after the ticker is created
return ticker
}
func (g *Guard) triggerOfferSending() {
select {
case g.Reconnect <- struct{}{}:
default:
}
}
// Give the peer a chance to establish the initial connection.
// This lets us avoid sending unnecessary offers.
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}
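Putting the pieces together, the Guard exposes a single Reconnect channel: the owner starts it with its controller flag and a connectivity check, and sends a fresh offer whenever the channel fires. A fragment-style sketch of that wiring, modeled on conn.go's listenGuardEvent in this PR; connLog, isController, isConnectedOnAllWay, srWatcher, sendOffer and ctx are assumed to exist and are illustrative names, not new API:

g := guard.NewGuard(connLog, isController, isConnectedOnAllWay, 10*time.Second, srWatcher)
go g.Start(ctx)

go func() {
    for {
        select {
        case <-g.Reconnect:
            // Connectivity degraded or the network changed: renegotiate with the peer.
            if err := sendOffer(); err != nil {
                connLog.Errorf("failed to send offer: %v", err)
            }
        case <-ctx.Done():
            return
        }
    }
}()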

View File

@ -0,0 +1,135 @@
package guard
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ICEMonitor struct {
ReconnectCh chan struct{}
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover,
iceConfig: config,
}
return cm
}
func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
ufrag, pwd, err := icemaker.GenerateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
changed, err := cm.handleCandidateTick(ctx, ufrag, pwd)
if err != nil {
log.Warnf("Failed to check ICE changes: %v", err)
continue
}
if changed {
onChanged()
}
case <-ctx.Done():
return
}
}
}
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return false, fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return false, fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return false, fmt.Errorf("wait for gathering timed out")
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return false, fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
return cm.updateCandidates(candidates), nil
}
func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}

View File

@ -0,0 +1,119 @@
package guard
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
type chNotifier interface {
SetOnReconnectedListener(func())
Ready() bool
}
type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}
// NewSRWatcher creates a new SRWatcher. This watcher notifies its listeners when the local ICE candidates change,
// the Relay connection is re-established, or the Signal client reconnects.
func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher {
srw := &SRWatcher{
signalClient: signalClient,
relayManager: relayManager,
iFaceDiscover: iFaceDiscover,
iceConfig: iceConfig,
listeners: make(map[chan struct{}]struct{}),
}
return srw
}
func (w *SRWatcher) Start() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)
}
func (w *SRWatcher) Close() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor == nil {
return
}
w.cancelIceMonitor()
w.signalClient.SetOnReconnectedListener(nil)
w.relayManager.SetOnReconnectedListener(nil)
}
func (w *SRWatcher) NewListener() chan struct{} {
w.mu.Lock()
defer w.mu.Unlock()
listenerChan := make(chan struct{}, 1)
w.listeners[listenerChan] = struct{}{}
return listenerChan
}
func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.listeners, listenerChan)
close(listenerChan)
}
func (w *SRWatcher) onICEChanged() {
if !w.signalClient.Ready() {
return
}
log.Infof("network changes detected by ICE agent")
w.notify()
}
func (w *SRWatcher) onReconnected() {
if !w.signalClient.Ready() {
return
}
if !w.relayManager.Ready() {
return
}
log.Infof("reconnected to Signal or Relay server")
w.notify()
}
func (w *SRWatcher) notify() {
w.mu.Lock()
defer w.mu.Unlock()
for listener := range w.listeners {
select {
case listener <- struct{}{}:
default:
}
}
}
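Consumers subscribe to the watcher with NewListener and must drop the subscription with RemoveListener, which also closes the channel; the Guard's reconnect loops above follow exactly this pattern. A small sketch, assuming a started *SRWatcher w, a context ctx, and a hypothetical resetBackoff callback:

events := w.NewListener()
defer w.RemoveListener(events) // also closes the channel

for {
    select {
    case <-events:
        // Signal reconnected, Relay reconnected, or the local ICE candidates changed.
        resetBackoff()
    case <-ctx.Done():
        return
    }
}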

View File

@ -0,0 +1,89 @@
package ice
import (
"runtime"
"time"
"github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
)
var (
failedTimeout = 6 * time.Second
)
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: config.StunTurn.Load().([]*stun.URI),
CandidateTypes: candidateTypes,
InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList),
UDPMux: config.UDPMux,
UDPMuxSrflx: config.UDPMuxSrflx,
NAT1To1IPs: config.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
}
if config.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
return ice.NewAgent(agentConfig)
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}
func CandidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func CandidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
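The extracted ice package now owns credential generation and agent construction, so callers like the ICE monitor above no longer build pion agents by hand. A condensed sketch of that flow, based on handleCandidateTick; gatherLocalCandidates is a hypothetical helper and the OnCandidate/gather-done synchronization is omitted for brevity:

// gatherLocalCandidates gathers host/srflx candidates using the new ice package.
func gatherLocalCandidates(iFaceDiscover stdnet.ExternalIFaceDiscover, cfg icemaker.Config) ([]ice.Candidate, error) {
    ufrag, pwd, err := icemaker.GenerateICECredentials()
    if err != nil {
        return nil, err
    }

    agent, err := icemaker.NewAgent(iFaceDiscover, cfg, icemaker.CandidateTypesP2P(), ufrag, pwd)
    if err != nil {
        return nil, fmt.Errorf("create ICE agent: %w", err)
    }
    defer func() {
        if err := agent.Close(); err != nil {
            log.Warnf("failed to close ICE agent: %v", err)
        }
    }()

    if err := agent.GatherCandidates(); err != nil {
        return nil, fmt.Errorf("gather ICE candidates: %w", err)
    }
    // The real monitor waits for OnCandidate(nil) before reading the result.
    return agent.GetLocalCandidates()
}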

View File

@ -0,0 +1,22 @@
package ice
import (
"sync/atomic"
"github.com/pion/ice/v3"
)
type Config struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}

View File

@ -1,4 +1,4 @@
package peer package ice
import ( import (
"os" "os"
@ -10,12 +10,19 @@ import (
) )
const ( const (
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
msgWarnInvalidValue = "invalid value %s set for %s, using default %v"
) )
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"
}
func iceKeepAlive() time.Duration { func iceKeepAlive() time.Duration {
keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec) keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec)
if keepAliveEnv == "" { if keepAliveEnv == "" {
@ -25,7 +32,7 @@ func iceKeepAlive() time.Duration {
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
if err != nil { if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) log.Warnf(msgWarnInvalidValue, keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
return iceKeepAliveDefault return iceKeepAliveDefault
} }
@ -41,7 +48,7 @@ func iceDisconnectedTimeout() time.Duration {
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
if err != nil { if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) log.Warnf(msgWarnInvalidValue, disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
return iceDisconnectedTimeoutDefault return iceDisconnectedTimeoutDefault
} }
@ -57,14 +64,9 @@ func iceRelayAcceptanceMinWait() time.Duration {
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
if err != nil { if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) log.Warnf(msgWarnInvalidValue, iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
return iceRelayAcceptanceMinWaitDefault return iceRelayAcceptanceMinWaitDefault
} }
return time.Duration(disconnectedTimeoutSec) * time.Second return time.Duration(disconnectedTimeoutSec) * time.Second
} }
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"
}

View File

@ -1,6 +1,6 @@
//go:build !android //go:build !android
package peer package ice
import ( import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"

View File

@ -1,4 +1,4 @@
package peer package ice
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"

View File

@ -5,52 +5,20 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
var (
failedTimeout = 6 * time.Second
)
type ICEConfig struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}
type ICEConnInfo struct { type ICEConnInfo struct {
RemoteConn net.Conn RemoteConn net.Conn
RosenpassPubKey []byte RosenpassPubKey []byte
@ -103,7 +71,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal
conn: callBacks, conn: callBacks,
} }
localUfrag, localPwd, err := generateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,10 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = candidateTypesP2P() preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else { } else {
w.selectedPriority = connPriorityICETurn w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = candidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
w.log.Debugf("recreate ICE agent") w.log.Debugf("recreate ICE agent")
@ -232,15 +200,10 @@ func (w *WorkerICE) Close() {
} }
} }
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
w.sentExtraSrflx = false w.sentExtraSrflx = false
agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil { if err != nil {
return nil, fmt.Errorf("create agent: %w", err) return nil, fmt.Errorf("create agent: %w", err)
} }
@ -365,36 +328,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
} }
} }
func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: candidateTypes,
InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
UDPMux: config.ICEConfig.UDPMux,
UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
}
if config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
return ice.NewAgent(agentConfig)
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@ -435,21 +368,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
return false return false
} }
func candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
func isRelayCandidate(candidate ice.Candidate) bool { func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
@ -460,16 +378,3 @@ func isRelayed(pair *ice.CandidatePair) bool {
} }
return false return false
} }
func generateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}
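The deleted generateICECredentials helper, whose job presumably moves along with the rest of the agent setup, built the local ufrag/pwd pair from pion's crypto-random string generator. A standalone sketch of the same idea; the original lenUFrag, lenPwd and runesAlpha constants are defined outside this hunk, so the values below are illustrative (RFC 8445 only requires ufrag >= 4 and pwd >= 22 characters):

// Sketch: generating local ICE credentials with pion's randutil.
package main

import (
    "fmt"

    "github.com/pion/randutil"
)

const (
    runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    lenUFrag   = 16 // illustrative length, not the original constant
    lenPwd     = 32 // illustrative length, not the original constant
)

func generateICECredentials() (ufrag, pwd string, err error) {
    ufrag, err = randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
    if err != nil {
        return "", "", err
    }
    pwd, err = randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
    if err != nil {
        return "", "", err
    }
    return ufrag, pwd, nil
}

func main() {
    u, p, err := generateICECredentials()
    if err != nil {
        panic(err)
    }
    fmt.Println(u, p)
}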


@ -31,6 +31,7 @@ type WorkerRelayCallbacks struct {
type WorkerRelay struct { type WorkerRelay struct {
log *log.Entry log *log.Entry
isController bool
config ConnConfig config ConnConfig
relayManager relayClient.ManagerService relayManager relayClient.ManagerService
callBacks WorkerRelayCallbacks callBacks WorkerRelayCallbacks
@ -44,9 +45,10 @@ type WorkerRelay struct {
relaySupportedOnRemotePeer atomic.Bool relaySupportedOnRemotePeer atomic.Bool
} }
func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
log: log, log: log,
isController: ctrl,
config: config, config: config,
relayManager: relayManager, relayManager: relayManager,
callBacks: callbacks, callBacks: callbacks,
@ -80,6 +82,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Errorf("failed to open connection via Relay: %s", err) w.log.Errorf("failed to open connection via Relay: %s", err)
return return
} }
w.relayLock.Lock() w.relayLock.Lock()
w.relayedConn = relayedConn w.relayedConn = relayedConn
w.relayLock.Unlock() w.relayLock.Unlock()
@ -136,10 +139,6 @@ func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool {
return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally() return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally()
} }
func (w *WorkerRelay) IsController() bool {
return w.config.LocalKey > w.config.Key
}
func (w *WorkerRelay) RelayIsSupportedLocally() bool { func (w *WorkerRelay) RelayIsSupportedLocally() bool {
return w.relayManager.HasRelayAddress() return w.relayManager.HasRelayAddress()
} }
@ -212,7 +211,7 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
} }
func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
if w.IsController() { if w.isController {
return myRelayAddress return myRelayAddress
} }
return remoteRelayAddress return remoteRelayAddress
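With this change the relay worker no longer derives the controller role by comparing keys on every call; the caller decides it once and injects it through NewWorkerRelay, and preferredRelayServer simply reads the stored flag. A sketch of the calling side, assuming the same lexical key comparison the removed IsController method used (the field and variable names here are illustrative):

// Sketch: computing the controller role once and reusing it for relay choice.
package main

import "fmt"

type ConnConfig struct {
    LocalKey string // local WireGuard public key
    Key      string // remote peer's WireGuard public key
}

// isController mirrors the removed WorkerRelay.IsController: the peer with
// the lexically greater key acts as controller of the relayed connection.
func isController(cfg ConnConfig) bool {
    return cfg.LocalKey > cfg.Key
}

// preferredRelayServer mirrors the method above: the controller's relay
// address wins, so both sides converge on the same server.
func preferredRelayServer(ctrl bool, myRelayAddress, remoteRelayAddress string) string {
    if ctrl {
        return myRelayAddress
    }
    return remoteRelayAddress
}

func main() {
    cfg := ConnConfig{LocalKey: "B-local-key", Key: "A-remote-key"}
    ctrl := isController(cfg) // computed once, then passed as the ctrl argument of NewWorkerRelay
    fmt.Println(ctrl, preferredRelayServer(ctrl, "relay-a.example.com", "relay-b.example.com"))
}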


@ -142,6 +142,7 @@ type Client struct {
muInstanceURL sync.Mutex muInstanceURL sync.Mutex
onDisconnectListener func() onDisconnectListener func()
onConnectedListener func()
listenerMutex sync.Mutex listenerMutex sync.Mutex
} }
@ -191,6 +192,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(c.relayConn)
go c.notifyConnected()
return nil return nil
} }
@ -238,6 +240,12 @@ func (c *Client) SetOnDisconnectListener(fn func()) {
c.onDisconnectListener = fn c.onDisconnectListener = fn
} }
func (c *Client) SetOnConnectedListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onConnectedListener = fn
}
// HasConns returns true if there are connections. // HasConns returns true if there are connections.
func (c *Client) HasConns() bool { func (c *Client) HasConns() bool {
c.mu.Lock() c.mu.Lock()
@ -245,6 +253,12 @@ func (c *Client) HasConns() bool {
return len(c.conns) > 0 return len(c.conns) > 0
} }
func (c *Client) Ready() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.serviceIsRunning
}
// Close closes the connection to the relay server and all connections to other peers. // Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error { func (c *Client) Close() error {
return c.close(true) return c.close(true)
@ -363,9 +377,9 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.instanceURL = nil c.instanceURL = nil
c.muInstanceURL.Unlock() c.muInstanceURL.Unlock()
c.notifyDisconnected()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
c.notifyDisconnected()
} }
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
@ -544,6 +558,16 @@ func (c *Client) notifyDisconnected() {
go c.onDisconnectListener() go c.onDisconnectListener()
} }
func (c *Client) notifyConnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onConnectedListener == nil {
return
}
go c.onConnectedListener()
}
func (c *Client) writeCloseMsg() { func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg() msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg) _, err := c.relayConn.Write(msg)
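Alongside the existing disconnect callback, the relay client now exposes a connect callback (fired from Connect via notifyConnected) and a Ready check, and notifyDisconnected is moved after close so listeners observe a fully torn-down client. A sketch of how a consumer might wire both listeners; the interface below exists only for the example, the real package uses the concrete *Client:

// Sketch: registering connect/disconnect listeners on a relay client.
package main

import "log"

type relayClient interface {
    Connect() error
    Ready() bool
    SetOnConnectedListener(func())
    SetOnDisconnectListener(func())
}

func wire(c relayClient, onUp, onDown func()) error {
    // Both callbacks are invoked by the client on its own goroutines
    // (notifyConnected / notifyDisconnected), so keep them non-blocking.
    c.SetOnConnectedListener(onUp)
    c.SetOnDisconnectListener(onDown)
    if err := c.Connect(); err != nil {
        return err
    }
    if !c.Ready() {
        log.Println("relay client connected but not reported ready")
    }
    return nil
}

func main() {} // compile-only sketch; no concrete client is constructed here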


@ -29,6 +29,10 @@ func NewGuard(context context.Context, relayClient *Client) *Guard {
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection // OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) OnDisconnected() { func (g *Guard) OnDisconnected() {
if g.quickReconnect() {
return
}
ticker := time.NewTicker(reconnectingTimeout) ticker := time.NewTicker(reconnectingTimeout)
defer ticker.Stop() defer ticker.Stop()
@ -46,3 +50,19 @@ func (g *Guard) OnDisconnected() {
} }
} }
} }
func (g *Guard) quickReconnect() bool {
ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond)
defer cancel()
<-ctx.Done()
if g.ctx.Err() != nil {
return false
}
if err := g.relayClient.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}
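quickReconnect gives the guard one fast retry roughly 1.5 seconds after a disconnect before it falls back to the slower reconnectingTimeout ticker; the timeout context serves purely as a cancellable sleep, and the parent-context check bails out when the guard itself is shutting down. A reduced sketch of the same wait-then-retry shape, with the connect call stubbed out:

// Sketch: a cancellable "sleep, then try once" helper in the shape of
// Guard.quickReconnect. connect stands in for relayClient.Connect.
package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

func quickReconnect(parent context.Context, connect func() error) bool {
    // Wait ~1.5s, waking up immediately if the parent context is cancelled.
    ctx, cancel := context.WithTimeout(parent, 1500*time.Millisecond)
    defer cancel()
    <-ctx.Done()

    if parent.Err() != nil {
        // The guard itself is shutting down; do not attempt a reconnect.
        return false
    }
    if err := connect(); err != nil {
        fmt.Printf("failed to reconnect to relay server: %s\n", err)
        return false
    }
    return true
}

func main() {
    ok := quickReconnect(context.Background(), func() error {
        return errors.New("relay still unreachable") // simulated failure
    })
    fmt.Println("quick reconnect succeeded:", ok) // when false, the caller falls back to the ticker loop
}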


@ -65,6 +65,7 @@ type Manager struct {
relayClientsMutex sync.RWMutex relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]*list.List onDisconnectedListeners map[string]*list.List
onReconnectedListenerFn func()
listenerLock sync.Mutex listenerLock sync.Mutex
} }
@ -101,6 +102,7 @@ func (m *Manager) Serve() error {
m.relayClient = client m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient) m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnConnectedListener(m.onServerConnected)
m.relayClient.SetOnDisconnectListener(func() { m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL) m.onServerDisconnected(client.connectionURL)
}) })
@ -138,6 +140,18 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
return netConn, err return netConn, err
} }
// Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool {
if m.relayClient == nil {
return false
}
return m.relayClient.Ready()
}
func (m *Manager) SetOnReconnectedListener(f func()) {
m.onReconnectedListenerFn = f
}
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
@ -240,6 +254,13 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
return conn, nil return conn, nil
} }
func (m *Manager) onServerConnected() {
if m.onReconnectedListenerFn == nil {
return
}
go m.onReconnectedListenerFn()
}
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.OnDisconnected() go m.reconnectGuard.OnDisconnected()
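The manager now forwards the home relay's connect event to a single reconnected listener (onServerConnected runs it on a fresh goroutine), mirroring how disconnects are forwarded to the reconnect guard. A sketch of that fan-out; the listener body is a hypothetical consumer, and the extra locking here is for the standalone example rather than a copy of the manager's fields:

// Sketch: forwarding a "server connected" event to an optional listener.
package main

import (
    "fmt"
    "sync"
    "time"
)

type manager struct {
    mu                      sync.Mutex
    onReconnectedListenerFn func()
}

func (m *manager) SetOnReconnectedListener(f func()) {
    m.mu.Lock()
    defer m.mu.Unlock()
    m.onReconnectedListenerFn = f
}

// onServerConnected mirrors the method above: a nil listener is ignored, and
// the callback runs on its own goroutine so the caller is never blocked.
func (m *manager) onServerConnected() {
    m.mu.Lock()
    fn := m.onReconnectedListenerFn
    m.mu.Unlock()
    if fn == nil {
        return
    }
    go fn()
}

func main() {
    m := &manager{}
    m.SetOnReconnectedListener(func() {
        // Hypothetical consumer: refresh peer connections once the home
        // relay connection is re-established.
        fmt.Println("relay is back, refreshing peer connections")
    })
    m.onServerConnected()
    time.Sleep(100 * time.Millisecond) // let the goroutine run in this toy example
}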


@ -35,6 +35,7 @@ type Client interface {
WaitStreamConnected() WaitStreamConnected()
SendToStream(msg *proto.EncryptedMessage) error SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error Send(msg *proto.Message) error
SetOnReconnectedListener(func())
} }
// UnMarshalCredential parses the credentials from the message and returns a Credential instance // UnMarshalCredential parses the credentials from the message and returns a Credential instance


@ -43,6 +43,8 @@ type GrpcClient struct {
connStateCallback ConnStateNotifier connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex connStateCallbackLock sync.RWMutex
onReconnectedListenerFn func()
} }
func (c *GrpcClient) StreamConnected() bool { func (c *GrpcClient) StreamConnected() bool {
@ -181,12 +183,17 @@ func (c *GrpcClient) notifyStreamDisconnected() {
func (c *GrpcClient) notifyStreamConnected() { func (c *GrpcClient) notifyStreamConnected() {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
c.status = StreamConnected c.status = StreamConnected
if c.connectedCh != nil { if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them // there are goroutines waiting on this channel -> release them
close(c.connectedCh) close(c.connectedCh)
c.connectedCh = nil c.connectedCh = nil
} }
if c.onReconnectedListenerFn != nil {
c.onReconnectedListenerFn()
}
} }
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
@ -271,6 +278,13 @@ func (c *GrpcClient) WaitStreamConnected() {
} }
} }
func (c *GrpcClient) SetOnReconnectedListener(fn func()) {
c.mux.Lock()
defer c.mux.Unlock()
c.onReconnectedListenerFn = fn
}
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// GrpcClient.connWg can be used to wait // GrpcClient.connWg can be used to wait
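The signal side gets the same hook: notifyStreamConnected, which already releases goroutines blocked in WaitStreamConnected, now also fires an optional reconnect listener registered through SetOnReconnectedListener. A small usage sketch against the method names shown in the diff; the handler body and helper names are illustrative:

// Sketch: registering a reconnect listener on a signal client. The interface
// lists only the methods relevant here (a subset of the Client interface in
// this diff); the concrete type would be *GrpcClient.
package main

import "log"

type signalClient interface {
    SetOnReconnectedListener(func())
    WaitStreamConnected()
}

func watchSignal(c signalClient, resendOffers func()) {
    // The listener is invoked from notifyStreamConnected while the client's
    // mutex is held, so it should return quickly and defer heavy work.
    c.SetOnReconnectedListener(func() {
        log.Println("signal stream re-established")
        go resendOffers() // hypothetical follow-up work
    })
    c.WaitStreamConnected()
}

func main() {} // compile-only sketch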


@ -7,14 +7,20 @@ import (
) )
type MockClient struct { type MockClient struct {
CloseFunc func() error CloseFunc func() error
GetStatusFunc func() Status GetStatusFunc func() Status
StreamConnectedFunc func() bool StreamConnectedFunc func() bool
ReadyFunc func() bool ReadyFunc func() bool
WaitStreamConnectedFunc func() WaitStreamConnectedFunc func()
ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error
SendToStreamFunc func(msg *proto.EncryptedMessage) error SendToStreamFunc func(msg *proto.EncryptedMessage) error
SendFunc func(msg *proto.Message) error SendFunc func(msg *proto.Message) error
SetOnReconnectedListenerFunc func(f func())
}
// SetOnReconnectedListener sets the function to be called when the client reconnects.
func (sm *MockClient) SetOnReconnectedListener(_ func()) {
// Do nothing
} }
func (sm *MockClient) IsHealthy() bool { func (sm *MockClient) IsHealthy() bool {