[client] Fix controller re-connection (#2758)

Rethink the peer reconnection implementation
This commit is contained in:
Zoltan Papp 2024-10-24 11:43:14 +02:00 committed by GitHub
parent 869537c951
commit 4e918e55ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 813 additions and 523 deletions

View File

@ -79,9 +79,6 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@ -98,7 +95,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin - name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
@ -106,7 +103,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@ -143,7 +143,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
conn, ok := b.endpoints[ep.DstIP()] conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock() b.endpointsMu.Unlock()
if !ok { if !ok {
log.Infof("failed to find endpoint for %s", ep.DstIP())
return b.StdNetBind.Send(bufs, ep) return b.StdNetBind.Send(bufs, ep)
} }

View File

@ -5,9 +5,9 @@ import (
"net" "net"
) )
const ( var (
portRangeStart = 3128 portRangeStart = 3128
portRangeEnd = 3228 portRangeEnd = portRangeStart + 100
) )
type portLookup struct { type portLookup struct {

View File

@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) {
func Test_portLookup_on_allocated(t *testing.T) { func Test_portLookup_on_allocated(t *testing.T) {
pl := portLookup{} pl := portLookup{}
portRangeStart = 4128
portRangeEnd = portRangeStart + 100
allocatedPort, err := allocatePort(portRangeStart) allocatedPort, err := allocatePort(portRangeStart)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -22,9 +22,11 @@ func NewKernelFactory(wgPort int) *KernelFactory {
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f return f
} }
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
f.ebpfProxy = ebpfProxy f.ebpfProxy = ebpfProxy
return f return f
} }

View File

@ -1,6 +1,8 @@
package wgproxy package wgproxy
import ( import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
@ -10,6 +12,7 @@ type KernelFactory struct {
} }
func NewKernelFactory(wgPort int) *KernelFactory { func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{ f := &KernelFactory{
wgPort: wgPort, wgPort: wgPort,
} }

View File

@ -1,6 +1,8 @@
package wgproxy package wgproxy
import ( import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
) )
@ -10,6 +12,7 @@ type USPFactory struct {
} }
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{ f := &USPFactory{
bind: iceBind, bind: iceBind,
} }

View File

@ -2,14 +2,16 @@ package udp
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@ -121,7 +123,7 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
return errors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
// proxyToRemote proxies from Wireguard to the RemoteKey // proxyToRemote proxies from Wireguard to the RemoteKey
@ -160,18 +162,16 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
if !errors.Is(err, io.EOF) {
log.Warnf("error in proxy to local loop: %s", err) log.Warnf("error in proxy to local loop: %s", err)
} }
}
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConnRead(ctx, buf)
if err != nil { if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }
@ -193,3 +193,15 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
} }
} }
} }
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
n, err = p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
return
}
return
}

View File

@ -30,6 +30,8 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@ -168,6 +170,7 @@ type Engine struct {
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -263,6 +266,10 @@ func (e *Engine) Stop() error {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
if e.srWatcher != nil {
e.srWatcher.Close()
}
err := e.removeAllPeers() err := e.removeAllPeers()
if err != nil { if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
@ -389,6 +396,18 @@ func (e *Engine) Start() error {
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveProbeEvents() e.receiveProbeEvents()
@ -971,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
RosenpassPubKey: e.getRosenpassPubKey(), RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(), RosenpassAddr: e.getRosenpassAddr(),
ICEConfig: peer.ICEConfig{ ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
@ -981,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -29,6 +29,8 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
type testCase struct { type testCase struct {
name string name string

View File

@ -10,7 +10,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -18,6 +17,8 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@ -32,8 +33,6 @@ const (
connPriorityRelay ConnPriority = 1 connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1 connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2 connPriorityICEP2P ConnPriority = 2
reconnectMaxElapsedTime = 30 * time.Minute
) )
type WgConfig struct { type WgConfig struct {
@ -63,7 +62,7 @@ type ConnConfig struct {
RosenpassAddr string RosenpassAddr string
// ICEConfig ICE protocol configuration // ICEConfig ICE protocol configuration
ICEConfig ICEConfig ICEConfig icemaker.Config
} }
type WorkerCallbacks struct { type WorkerCallbacks struct {
@ -106,16 +105,12 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
// for reconnection operations guard *guard.Guard
iCEDisconnected chan bool
relayDisconnected chan bool
connMonitor *ConnMonitor
reconnectCh <-chan struct{}
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil { if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err) log.Errorf("failed to parse allowedIPS: %v", err)
@ -137,18 +132,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
allowedNet: allowedNet.String(), allowedNet: allowedNet.String(),
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
} }
conn.connMonitor, conn.reconnectCh = NewConnMonitor(
signaler,
iFaceDiscover,
config,
conn.relayDisconnected,
conn.iCEDisconnected,
)
rFns := WorkerRelayCallbacks{ rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady, OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected, OnDisconnected: conn.onWorkerRelayStateDisconnected,
@ -159,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
OnStatusChanged: conn.onWorkerICEStateDisconnected, OnStatusChanged: conn.onWorkerICEStateDisconnected,
} }
conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns) ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
@ -174,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
} }
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen() go conn.handshaker.Listen()
return conn, nil return conn, nil
@ -184,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open() {
conn.log.Debugf("open connection to peer") conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.opened = true conn.opened = true
@ -200,24 +189,19 @@ func (conn *Conn) Open() {
conn.log.Warnf("error while updating the state err: %v", err) conn.log.Warnf("error while updating the state err: %v", err)
} }
go conn.startHandshakeAndReconnect() go conn.startHandshakeAndReconnect(conn.ctx)
} }
func (conn *Conn) startHandshakeAndReconnect() { func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
conn.waitInitialRandomSleepTime() conn.waitInitialRandomSleepTime(ctx)
err := conn.handshaker.sendOffer() err := conn.handshaker.sendOffer()
if err != nil { if err != nil {
conn.log.Errorf("failed to send initial offer: %v", err) conn.log.Errorf("failed to send initial offer: %v", err)
} }
go conn.connMonitor.Start(conn.ctx) go conn.guard.Start(ctx)
go conn.listenGuardEvent(ctx)
if conn.workerRelay.IsController() {
conn.reconnectLoopWithRetry()
} else {
conn.reconnectLoopForOnDisconnectedEvent()
}
} }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
@ -316,104 +300,6 @@ func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) reconnectLoopWithRetry() {
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
select {
case <-conn.ctx.Done():
return
case <-time.After(3 * time.Second):
}
ticker := conn.prepareExponentTicker()
defer ticker.Stop()
time.Sleep(1 * time.Second)
for {
select {
case t := <-ticker.C:
if t.IsZero() {
// in case if the ticker has been canceled by context then avoid the temporary loop
return
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE.Get() == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
// checks if there is peer connection is established via relay or ice
if conn.isConnected() {
continue
}
err := conn.handshaker.sendOffer()
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
case <-conn.reconnectCh:
ticker.Stop()
ticker = conn.prepareExponentTicker()
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: conn.config.Timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, conn.ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}
// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
for {
select {
case changed := <-conn.relayDisconnected:
if !changed {
continue
}
conn.log.Debugf("Relay state changed, try to send new offer")
case changed := <-conn.iCEDisconnected:
if !changed {
continue
}
conn.log.Debugf("ICE state changed, try to send new offer")
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
}
err := conn.handshaker.SendOffer()
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
}
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
@ -513,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState) conn.statusICE.Set(newState)
conn.notifyReconnectLoopICEDisconnected(changed) conn.guard.SetICEConnDisconnected(changed)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -604,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.Set(StatusDisconnected)
conn.notifyReconnectLoopRelayDisconnected(changed) conn.guard.SetRelayedConnDisconnected(changed)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -617,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
} }
} }
func (conn *Conn) listenGuardEvent(ctx context.Context) {
for {
select {
case <-conn.guard.Reconnect:
conn.log.Debugf("send offer to peer")
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
return conn.config.WgConfig.WgInterface.UpdatePeer( return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey, conn.config.WgConfig.RemoteKey,
@ -693,7 +593,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
} }
func (conn *Conn) waitInitialRandomSleepTime() { func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
minWait := 100 minWait := 100
maxWait := 800 maxWait := 800
duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
@ -702,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
defer timeout.Stop() defer timeout.Stop()
select { select {
case <-conn.ctx.Done(): case <-ctx.Done():
case <-timeout.C: case <-timeout.C:
} }
} }
@ -731,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusDisconnected return StatusDisconnected
} }
func (conn *Conn) isConnected() bool { func (conn *Conn) isConnectedOnAllWay() (connected bool) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { defer func() {
if !connected {
conn.logTraceConnState()
}
}()
if conn.statusICE.Get() == StatusDisconnected {
return false return false
} }
@ -805,20 +711,6 @@ func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
} }
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err) conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil { if wgProxy != nil {
@ -831,6 +723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
} }
} }
func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
}
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }

View File

@ -1,212 +0,0 @@
package peer
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
signalerMonitorPeriod = 5 * time.Second
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ConnMonitor struct {
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
config ConnConfig
relayDisconnected chan bool
iCEDisconnected chan bool
reconnectCh chan struct{}
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
reconnectCh := make(chan struct{}, 1)
cm := &ConnMonitor{
signaler: signaler,
iFaceDiscover: iFaceDiscover,
config: config,
relayDisconnected: relayDisconnected,
iCEDisconnected: iCEDisconnected,
reconnectCh: reconnectCh,
}
return cm, reconnectCh
}
func (cm *ConnMonitor) Start(ctx context.Context) {
signalerReady := make(chan struct{}, 1)
go cm.monitorSignalerReady(ctx, signalerReady)
localCandidatesChanged := make(chan struct{}, 1)
go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
for {
select {
case changed := <-cm.relayDisconnected:
if !changed {
continue
}
log.Debugf("Relay state changed, triggering reconnect")
cm.triggerReconnect()
case changed := <-cm.iCEDisconnected:
if !changed {
continue
}
log.Debugf("ICE state changed, triggering reconnect")
cm.triggerReconnect()
case <-signalerReady:
log.Debugf("Signaler became ready, triggering reconnect")
cm.triggerReconnect()
case <-localCandidatesChanged:
log.Debugf("Local candidates changed, triggering reconnect")
cm.triggerReconnect()
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
if cm.signaler == nil {
return
}
ticker := time.NewTicker(signalerMonitorPeriod)
defer ticker.Stop()
lastReady := true
for {
select {
case <-ticker.C:
currentReady := cm.signaler.Ready()
if !lastReady && currentReady {
select {
case signalerReady <- struct{}{}:
default:
}
}
lastReady = currentReady
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
ufrag, pwd, err := generateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
log.Warnf("Failed to handle candidate tick: %v", err)
}
case <-ctx.Done():
return
}
}
}
func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
log.Debugf("Gathering ICE candidates")
transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("wait for gathering: %w", ctx.Err())
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
if changed := cm.updateCandidates(candidates); changed {
select {
case localCandidatesChanged <- struct{}{}:
default:
}
}
return nil
}
func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func (cm *ConnMonitor) triggerReconnect() {
select {
case cm.reconnectCh <- struct{}{}:
default:
}
}

View File

@ -10,6 +10,8 @@ import (
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -19,7 +21,7 @@ var connConf = ConnConfig{
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
Timeout: time.Second, Timeout: time.Second,
LocalWgPort: 51820, LocalWgPort: 51820,
ICEConfig: ICEConfig{ ICEConfig: ice.Config{
InterfaceBlackList: nil, InterfaceBlackList: nil,
}, },
} }
@ -43,7 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }
@ -54,7 +57,8 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }
@ -87,7 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }
@ -119,7 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
if err != nil { if err != nil {
return return
} }

View File

@ -0,0 +1,194 @@
package guard
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
const (
reconnectMaxElapsedTime = 30 * time.Minute
)
type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues.
// Watch these events:
// - Relay client reconnected to home server
// - Signal server connection state changed
// - ICE connection disconnected
// - Relayed connection disconnected
// - ICE candidate changes
type Guard struct {
Reconnect chan struct{}
log *log.Entry
isController bool
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan bool
iCEConnDisconnected chan bool
}
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
Reconnect: make(chan struct{}, 1),
log: log,
isController: isController,
isConnectedOnAllWay: isConnectedFn,
timeout: timeout,
srWatcher: srWatcher,
relayedConnDisconnected: make(chan bool, 1),
iCEConnDisconnected: make(chan bool, 1),
}
}
func (g *Guard) Start(ctx context.Context) {
if g.isController {
g.reconnectLoopWithRetry(ctx)
} else {
g.listenForDisconnectEvents(ctx)
}
}
func (g *Guard) SetRelayedConnDisconnected(changed bool) {
select {
case g.relayedConnDisconnected <- changed:
default:
}
}
func (g *Guard) SetICEConnDisconnected(changed bool) {
select {
case g.iCEConnDisconnected <- changed:
default:
}
}
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx)
defer ticker.Stop()
tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
tickerChannel = make(<-chan time.Time)
continue
}
if !g.isConnectedOnAllWay() {
g.triggerOfferSending()
}
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
g.log.Infof("start listen for reconnect events...")
for {
select {
case changed := <-g.relayedConnDisconnected:
if !changed {
continue
}
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case changed := <-g.iCEConnDisconnected:
if !changed {
continue
}
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
g.triggerOfferSending()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
ticker := backoff.NewTicker(bo)
<-ticker.C // consume the initial tick what is happening right after the ticker has been created
return ticker
}
func (g *Guard) triggerOfferSending() {
select {
case g.Reconnect <- struct{}{}:
default:
}
}
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}

View File

@ -0,0 +1,135 @@
package guard
import (
"context"
"fmt"
"sync"
"time"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
const (
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ICEMonitor struct {
ReconnectCh chan struct{}
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
}
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover,
iceConfig: config,
}
return cm
}
func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
ufrag, pwd, err := icemaker.GenerateICECredentials()
if err != nil {
log.Warnf("Failed to generate ICE credentials: %v", err)
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
changed, err := cm.handleCandidateTick(ctx, ufrag, pwd)
if err != nil {
log.Warnf("Failed to check ICE changes: %v", err)
continue
}
if changed {
onChanged()
}
case <-ctx.Done():
return
}
}
}
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err)
}
defer func() {
if err := agent.Close(); err != nil {
log.Warnf("Failed to close ICE agent: %v", err)
}
}()
gatherDone := make(chan struct{})
err = agent.OnCandidate(func(c ice.Candidate) {
log.Tracef("Got candidate: %v", c)
if c == nil {
close(gatherDone)
}
})
if err != nil {
return false, fmt.Errorf("set ICE candidate handler: %w", err)
}
if err := agent.GatherCandidates(); err != nil {
return false, fmt.Errorf("gather ICE candidates: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
defer cancel()
select {
case <-ctx.Done():
return false, fmt.Errorf("wait for gathering timed out")
case <-gatherDone:
}
candidates, err := agent.GetLocalCandidates()
if err != nil {
return false, fmt.Errorf("get local candidates: %w", err)
}
log.Tracef("Got candidates: %v", candidates)
return cm.updateCandidates(candidates), nil
}
func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
}
return false
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}

View File

@ -0,0 +1,119 @@
package guard
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
type chNotifier interface {
SetOnReconnectedListener(func())
Ready() bool
}
type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}
// NewSRWatcher creates a new SRWatcher. This watcher will notify the listeners when the ICE candidates change or the
// Relay connection is reconnected or the Signal client reconnected.
func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher {
srw := &SRWatcher{
signalClient: signalClient,
relayManager: relayManager,
iFaceDiscover: iFaceDiscover,
iceConfig: iceConfig,
listeners: make(map[chan struct{}]struct{}),
}
return srw
}
func (w *SRWatcher) Start() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)
}
func (w *SRWatcher) Close() {
w.mu.Lock()
defer w.mu.Unlock()
if w.cancelIceMonitor == nil {
return
}
w.cancelIceMonitor()
w.signalClient.SetOnReconnectedListener(nil)
w.relayManager.SetOnReconnectedListener(nil)
}
func (w *SRWatcher) NewListener() chan struct{} {
w.mu.Lock()
defer w.mu.Unlock()
listenerChan := make(chan struct{}, 1)
w.listeners[listenerChan] = struct{}{}
return listenerChan
}
func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.listeners, listenerChan)
close(listenerChan)
}
func (w *SRWatcher) onICEChanged() {
if !w.signalClient.Ready() {
return
}
log.Infof("network changes detected by ICE agent")
w.notify()
}
func (w *SRWatcher) onReconnected() {
if !w.signalClient.Ready() {
return
}
if !w.relayManager.Ready() {
return
}
log.Infof("reconnected to Signal or Relay server")
w.notify()
}
func (w *SRWatcher) notify() {
w.mu.Lock()
defer w.mu.Unlock()
for listener := range w.listeners {
select {
case listener <- struct{}{}:
default:
}
}
}

View File

@ -0,0 +1,89 @@
package ice
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"runtime"
"time"
)
const (
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
)
var (
failedTimeout = 6 * time.Second
)
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: config.StunTurn.Load().([]*stun.URI),
CandidateTypes: candidateTypes,
InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList),
UDPMux: config.UDPMux,
UDPMuxSrflx: config.UDPMuxSrflx,
NAT1To1IPs: config.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
}
if config.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
return ice.NewAgent(agentConfig)
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}
func CandidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func CandidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}

View File

@ -0,0 +1,22 @@
package ice
import (
"sync/atomic"
"github.com/pion/ice/v3"
)
type Config struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}

View File

@ -1,4 +1,4 @@
package peer package ice
import ( import (
"os" "os"
@ -10,12 +10,19 @@ import (
) )
const ( const (
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC"
envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC"
envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC"
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
msgWarnInvalidValue = "invalid value %s set for %s, using default %v"
) )
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"
}
func iceKeepAlive() time.Duration { func iceKeepAlive() time.Duration {
keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec) keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec)
if keepAliveEnv == "" { if keepAliveEnv == "" {
@ -25,7 +32,7 @@ func iceKeepAlive() time.Duration {
log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv)
keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv)
if err != nil { if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) log.Warnf(msgWarnInvalidValue, keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault)
return iceKeepAliveDefault return iceKeepAliveDefault
} }
@ -41,7 +48,7 @@ func iceDisconnectedTimeout() time.Duration {
log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv)
disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv)
if err != nil { if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) log.Warnf(msgWarnInvalidValue, disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault)
return iceDisconnectedTimeoutDefault return iceDisconnectedTimeoutDefault
} }
@ -57,14 +64,9 @@ func iceRelayAcceptanceMinWait() time.Duration {
log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv)
disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv)
if err != nil { if err != nil {
log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) log.Warnf(msgWarnInvalidValue, iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault)
return iceRelayAcceptanceMinWaitDefault return iceRelayAcceptanceMinWaitDefault
} }
return time.Duration(disconnectedTimeoutSec) * time.Second return time.Duration(disconnectedTimeoutSec) * time.Second
} }
func hasICEForceRelayConn() bool {
disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn)
return strings.ToLower(disconnectedTimeoutEnv) == "true"
}

View File

@ -1,6 +1,6 @@
//go:build !android //go:build !android
package peer package ice
import ( import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"

View File

@ -1,4 +1,4 @@
package peer package ice
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"

View File

@ -5,52 +5,20 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
var (
failedTimeout = 6 * time.Second
)
type ICEConfig struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn *atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}
type ICEConnInfo struct { type ICEConnInfo struct {
RemoteConn net.Conn RemoteConn net.Conn
RosenpassPubKey []byte RosenpassPubKey []byte
@ -103,7 +71,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal
conn: callBacks, conn: callBacks,
} }
localUfrag, localPwd, err := generateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,10 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = candidateTypesP2P() preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else { } else {
w.selectedPriority = connPriorityICETurn w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = candidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
w.log.Debugf("recreate ICE agent") w.log.Debugf("recreate ICE agent")
@ -232,15 +200,10 @@ func (w *WorkerICE) Close() {
} }
} }
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
w.sentExtraSrflx = false w.sentExtraSrflx = false
agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil { if err != nil {
return nil, fmt.Errorf("create agent: %w", err) return nil, fmt.Errorf("create agent: %w", err)
} }
@ -365,36 +328,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
} }
} }
func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI),
CandidateTypes: candidateTypes,
InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
UDPMux: config.ICEConfig.UDPMux,
UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx,
NAT1To1IPs: config.ICEConfig.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
}
if config.ICEConfig.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
return ice.NewAgent(agentConfig)
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@ -435,21 +368,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
return false return false
} }
func candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
func isRelayCandidate(candidate ice.Candidate) bool { func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay return candidate.Type() == ice.CandidateTypeRelay
} }
@ -460,16 +378,3 @@ func isRelayed(pair *ice.CandidatePair) bool {
} }
return false return false
} }
func generateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}

View File

@ -31,6 +31,7 @@ type WorkerRelayCallbacks struct {
type WorkerRelay struct { type WorkerRelay struct {
log *log.Entry log *log.Entry
isController bool
config ConnConfig config ConnConfig
relayManager relayClient.ManagerService relayManager relayClient.ManagerService
callBacks WorkerRelayCallbacks callBacks WorkerRelayCallbacks
@ -44,9 +45,10 @@ type WorkerRelay struct {
relaySupportedOnRemotePeer atomic.Bool relaySupportedOnRemotePeer atomic.Bool
} }
func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
log: log, log: log,
isController: ctrl,
config: config, config: config,
relayManager: relayManager, relayManager: relayManager,
callBacks: callbacks, callBacks: callbacks,
@ -80,6 +82,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Errorf("failed to open connection via Relay: %s", err) w.log.Errorf("failed to open connection via Relay: %s", err)
return return
} }
w.relayLock.Lock() w.relayLock.Lock()
w.relayedConn = relayedConn w.relayedConn = relayedConn
w.relayLock.Unlock() w.relayLock.Unlock()
@ -136,10 +139,6 @@ func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool {
return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally() return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally()
} }
func (w *WorkerRelay) IsController() bool {
return w.config.LocalKey > w.config.Key
}
func (w *WorkerRelay) RelayIsSupportedLocally() bool { func (w *WorkerRelay) RelayIsSupportedLocally() bool {
return w.relayManager.HasRelayAddress() return w.relayManager.HasRelayAddress()
} }
@ -212,7 +211,7 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
} }
func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
if w.IsController() { if w.isController {
return myRelayAddress return myRelayAddress
} }
return remoteRelayAddress return remoteRelayAddress

View File

@ -142,6 +142,7 @@ type Client struct {
muInstanceURL sync.Mutex muInstanceURL sync.Mutex
onDisconnectListener func() onDisconnectListener func()
onConnectedListener func()
listenerMutex sync.Mutex listenerMutex sync.Mutex
} }
@ -191,6 +192,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(c.relayConn)
go c.notifyConnected()
return nil return nil
} }
@ -238,6 +240,12 @@ func (c *Client) SetOnDisconnectListener(fn func()) {
c.onDisconnectListener = fn c.onDisconnectListener = fn
} }
func (c *Client) SetOnConnectedListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onConnectedListener = fn
}
// HasConns returns true if there are connections. // HasConns returns true if there are connections.
func (c *Client) HasConns() bool { func (c *Client) HasConns() bool {
c.mu.Lock() c.mu.Lock()
@ -245,6 +253,12 @@ func (c *Client) HasConns() bool {
return len(c.conns) > 0 return len(c.conns) > 0
} }
func (c *Client) Ready() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.serviceIsRunning
}
// Close closes the connection to the relay server and all connections to other peers. // Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error { func (c *Client) Close() error {
return c.close(true) return c.close(true)
@ -363,9 +377,9 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.instanceURL = nil c.instanceURL = nil
c.muInstanceURL.Unlock() c.muInstanceURL.Unlock()
c.notifyDisconnected()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
c.notifyDisconnected()
} }
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
@ -544,6 +558,16 @@ func (c *Client) notifyDisconnected() {
go c.onDisconnectListener() go c.onDisconnectListener()
} }
func (c *Client) notifyConnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onConnectedListener == nil {
return
}
go c.onConnectedListener()
}
func (c *Client) writeCloseMsg() { func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg() msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg) _, err := c.relayConn.Write(msg)

View File

@ -29,6 +29,10 @@ func NewGuard(context context.Context, relayClient *Client) *Guard {
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection // OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) OnDisconnected() { func (g *Guard) OnDisconnected() {
if g.quickReconnect() {
return
}
ticker := time.NewTicker(reconnectingTimeout) ticker := time.NewTicker(reconnectingTimeout)
defer ticker.Stop() defer ticker.Stop()
@ -46,3 +50,19 @@ func (g *Guard) OnDisconnected() {
} }
} }
} }
func (g *Guard) quickReconnect() bool {
ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond)
defer cancel()
<-ctx.Done()
if g.ctx.Err() != nil {
return false
}
if err := g.relayClient.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}

View File

@ -65,6 +65,7 @@ type Manager struct {
relayClientsMutex sync.RWMutex relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]*list.List onDisconnectedListeners map[string]*list.List
onReconnectedListenerFn func()
listenerLock sync.Mutex listenerLock sync.Mutex
} }
@ -101,6 +102,7 @@ func (m *Manager) Serve() error {
m.relayClient = client m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient) m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnConnectedListener(m.onServerConnected)
m.relayClient.SetOnDisconnectListener(func() { m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL) m.onServerDisconnected(client.connectionURL)
}) })
@ -138,6 +140,18 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
return netConn, err return netConn, err
} }
// Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool {
if m.relayClient == nil {
return false
}
return m.relayClient.Ready()
}
func (m *Manager) SetOnReconnectedListener(f func()) {
m.onReconnectedListenerFn = f
}
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed. // closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
@ -240,6 +254,13 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
return conn, nil return conn, nil
} }
func (m *Manager) onServerConnected() {
if m.onReconnectedListenerFn == nil {
return
}
go m.onReconnectedListenerFn()
}
func (m *Manager) onServerDisconnected(serverAddress string) { func (m *Manager) onServerDisconnected(serverAddress string) {
if serverAddress == m.relayClient.connectionURL { if serverAddress == m.relayClient.connectionURL {
go m.reconnectGuard.OnDisconnected() go m.reconnectGuard.OnDisconnected()

View File

@ -35,6 +35,7 @@ type Client interface {
WaitStreamConnected() WaitStreamConnected()
SendToStream(msg *proto.EncryptedMessage) error SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error Send(msg *proto.Message) error
SetOnReconnectedListener(func())
} }
// UnMarshalCredential parses the credentials from the message and returns a Credential instance // UnMarshalCredential parses the credentials from the message and returns a Credential instance

View File

@ -43,6 +43,8 @@ type GrpcClient struct {
connStateCallback ConnStateNotifier connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex connStateCallbackLock sync.RWMutex
onReconnectedListenerFn func()
} }
func (c *GrpcClient) StreamConnected() bool { func (c *GrpcClient) StreamConnected() bool {
@ -181,12 +183,17 @@ func (c *GrpcClient) notifyStreamDisconnected() {
func (c *GrpcClient) notifyStreamConnected() { func (c *GrpcClient) notifyStreamConnected() {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
c.status = StreamConnected c.status = StreamConnected
if c.connectedCh != nil { if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them // there are goroutines waiting on this channel -> release them
close(c.connectedCh) close(c.connectedCh)
c.connectedCh = nil c.connectedCh = nil
} }
if c.onReconnectedListenerFn != nil {
c.onReconnectedListenerFn()
}
} }
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
@ -271,6 +278,13 @@ func (c *GrpcClient) WaitStreamConnected() {
} }
} }
func (c *GrpcClient) SetOnReconnectedListener(fn func()) {
c.mux.Lock()
defer c.mux.Unlock()
c.onReconnectedListenerFn = fn
}
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// GrpcClient.connWg can be used to wait // GrpcClient.connWg can be used to wait

View File

@ -15,6 +15,12 @@ type MockClient struct {
ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error
SendToStreamFunc func(msg *proto.EncryptedMessage) error SendToStreamFunc func(msg *proto.EncryptedMessage) error
SendFunc func(msg *proto.Message) error SendFunc func(msg *proto.Message) error
SetOnReconnectedListenerFunc func(f func())
}
// SetOnReconnectedListener sets the function to be called when the client reconnects.
func (sm *MockClient) SetOnReconnectedListener(_ func()) {
// Do nothing
} }
func (sm *MockClient) IsHealthy() bool { func (sm *MockClient) IsHealthy() bool {