[client, relay-server] Feature/relay notification (#4083)

- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
This commit is contained in:
Zoltan Papp
2025-07-15 10:43:42 +02:00
committed by GitHub
parent e49bcc343d
commit 0dab03252c
39 changed files with 1464 additions and 495 deletions

View File

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type ProxyBind struct {
@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
} }
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
} }
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }

View File

@ -11,6 +11,8 @@ import (
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() { func (p *ProxyWrapper) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil { if ctx.Err() != nil {
return 0, ctx.Err() return 0, ctx.Err()
} }
p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
} }

View File

@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort)
} }
return &ebpf.ProxyWrapper{ return ebpf.NewProxyWrapper(w.ebpfProxy)
WgeBPFProxy: w.ebpfProxy,
}
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{ return proxyBind.NewProxyBind(w.bind)
Bind: w.bind,
}
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@ -0,0 +1,19 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func())
} }

View File

@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err) t.Errorf("failed to free ebpf proxy: %s", err)
} }
}() }()
proxyWrapper := &ebpf.ProxyWrapper{ proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct { tests = append(tests, struct {
name string name string

View File

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
} }
return p return p
} }
@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr return endpointUdpAddr
} }
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused // Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() { func (p *WGUDPProxy) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
return return
} }

View File

@ -167,7 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@ -489,6 +489,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())

View File

@ -19,6 +19,7 @@ type RelayConnInfo struct {
} }
type WorkerRelay struct { type WorkerRelay struct {
peerCtx context.Context
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
@ -33,8 +34,9 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx,
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")

View File

@ -7,13 +7,6 @@ import (
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
) )
// Validator is an interface that defines the Validate method.
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
type TimedHMACValidator struct { type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator authenticator *auth.TimedHMACValidator

View File

@ -124,15 +124,14 @@ func (cc *connContainer) close() {
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. // While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct { type Client struct {
log *log.Entry log *log.Entry
parentCtx context.Context
connectionURL string connectionURL string
authTokenStore *auth.TokenStore authTokenStore *auth.TokenStore
hashedID []byte hashedID messages.PeerID
bufPool *sync.Pool bufPool *sync.Pool
relayConn net.Conn relayConn net.Conn
conns map[string]*connContainer conns map[messages.PeerID]*connContainer
serviceIsRunning bool serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex readLoopMutex sync.Mutex
@ -142,14 +141,17 @@ type Client struct {
onDisconnectListener func(string) onDisconnectListener func(string)
listenerMutex sync.Mutex listenerMutex sync.Mutex
stateSubscription *PeersStateSubscription
} }
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
c := &Client{ c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}), log: relayLog,
parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
hashedID: hashedID, hashedID: hashedID,
@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
return &buf return &buf
}, },
}, },
conns: make(map[string]*connContainer), conns: make(map[messages.PeerID]*connContainer),
} }
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
return c return c
} }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error { func (c *Client) Connect(ctx context.Context) error {
c.log.Infof("connecting to relay server") c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock() c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock() defer c.readLoopMutex.Unlock()
@ -178,17 +181,23 @@ func (c *Client) Connect() error {
return nil return nil
} }
if err := c.connect(); err != nil { if err := c.connect(ctx); err != nil {
return err return err
} }
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
c.log = c.log.WithField("relay", c.instanceURL.String()) c.log = c.log.WithField("relay", c.instanceURL.String())
c.log.Infof("relay connection established") c.log.Infof("relay connection established")
c.serviceIsRunning = true c.serviceIsRunning = true
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
c.wgReadLoop.Add(1) c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn) go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
return nil return nil
} }
@ -196,26 +205,41 @@ func (c *Client) Connect() error {
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress // OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise, // to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately. // it will return immediately.
// It block until the server confirm the peer is online.
// todo: what should happen if call with the same peerID with multiple times? // todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
c.mu.Lock() peerID := messages.HashID(dstPeerID)
defer c.mu.Unlock()
c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established") return nil, fmt.Errorf("relay connection is not established")
} }
_, ok := c.conns[peerID]
hashedID, hashedStringID := messages.HashID(dstPeerID)
_, ok := c.conns[hashedStringID]
if ok { if ok {
c.mu.Unlock()
return nil, ErrConnAlreadyExists return nil, ErrConnAlreadyExists
} }
c.mu.Unlock()
c.log.Infof("open connection to peer: %s", hashedStringID) if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100) msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) conn := NewConn(c, peerID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel) c.mu.Lock()
_, ok = c.conns[peerID]
if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists
}
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
c.mu.Unlock()
return conn, nil return conn, nil
} }
@ -254,7 +278,7 @@ func (c *Client) Close() error {
return c.close(true) return c.close(true)
} }
func (c *Client) connect() error { func (c *Client) connect(ctx context.Context) error {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial() conn, err := rd.Dial()
if err != nil { if err != nil {
@ -262,7 +286,7 @@ func (c *Client) connect() error {
} }
c.relayConn = conn c.relayConn = conn
if err = c.handShake(); err != nil { if err = c.handShake(ctx); err != nil {
cErr := conn.Close() cErr := conn.Close()
if cErr != nil { if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr) c.log.Errorf("failed to close connection: %s", cErr)
@ -273,7 +297,7 @@ func (c *Client) connect() error {
return nil return nil
} }
func (c *Client) handShake() error { func (c *Client) handShake(ctx context.Context) error {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil { if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err) c.log.Errorf("failed to marshal auth message: %s", err)
@ -286,7 +310,7 @@ func (c *Client) handShake() error {
return err return err
} }
buf := make([]byte, messages.MaxHandshakeRespSize) buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf) n, err := c.readWithTimeout(ctx, buf)
if err != nil { if err != nil {
c.log.Errorf("failed to read auth response: %s", err) c.log.Errorf("failed to read auth response: %s", err)
return err return err
@ -319,11 +343,7 @@ func (c *Client) handShake() error {
return nil return nil
} }
func (c *Client) readLoop(relayConn net.Conn) { func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var ( var (
errExit error errExit error
n int n int
@ -370,6 +390,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.instanceURL = nil c.instanceURL = nil
c.muInstanceURL.Unlock() c.muInstanceURL.Unlock()
c.stateSubscription.Cleanup()
c.wgReadLoop.Done() c.wgReadLoop.Done()
_ = c.close(false) _ = c.close(false)
c.notifyDisconnected() c.notifyDisconnected()
@ -382,6 +403,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport: case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypePeersOnline:
c.handlePeersOnlineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypePeersWentOffline:
c.handlePeersWentOfflineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypeClose: case messages.MsgTypeClose:
c.log.Debugf("relay connection close by server") c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
@ -413,18 +442,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
stringID := messages.HashIDToString(peerID)
c.mu.Lock() c.mu.Lock()
if !c.serviceIsRunning { if !c.serviceIsRunning {
c.mu.Unlock() c.mu.Unlock()
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return false return false
} }
container, ok := c.conns[stringID] container, ok := c.conns[*peerID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
c.log.Errorf("peer not found: %s", stringID) c.log.Errorf("peer not found: %s", peerID.String())
c.bufPool.Put(bufPtr) c.bufPool.Put(bufPtr)
return true return true
} }
@ -437,9 +464,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
return true return true
} }
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
c.mu.Lock() c.mu.Lock()
conn, ok := c.conns[id] conn, ok := c.conns[dstID]
c.mu.Unlock() c.mu.Unlock()
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
@ -464,7 +491,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
return len(payload), err return len(payload), err
} }
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for { for {
select { select {
case _, ok := <-hc.OnTimeout: case _, ok := <-hc.OnTimeout:
@ -478,7 +505,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
c.log.Warnf("failed to close connection: %s", err) c.log.Warnf("failed to close connection: %s", err)
} }
return return
case <-c.parentCtx.Done(): case <-ctx.Done():
err := c.close(true) err := c.close(true)
if err != nil { if err != nil {
c.log.Errorf("failed to teardown connection: %s", err) c.log.Errorf("failed to teardown connection: %s", err)
@ -492,10 +519,31 @@ func (c *Client) closeAllConns() {
for _, container := range c.conns { for _, container := range c.conns {
container.close() container.close()
} }
c.conns = make(map[string]*connContainer) c.conns = make(map[messages.PeerID]*connContainer)
} }
func (c *Client) closeConn(connReference *Conn, id string) error { func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
c.mu.Lock()
defer c.mu.Unlock()
for _, peerID := range peerIDs {
container, ok := c.conns[peerID]
if !ok {
c.log.Warnf("can not close connection, peer not found: %s", peerID)
continue
}
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
container.close()
delete(c.conns, peerID)
}
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
}
}
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -507,6 +555,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
}
c.log.Infof("free up connection to peer: %s", id) c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id) delete(c.conns, id)
container.close() container.close()
@ -559,8 +612,8 @@ func (c *Client) writeCloseMsg() {
} }
} }
func (c *Client) readWithTimeout(buf []byte) (int, error) { func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
defer cancel() defer cancel()
readDone := make(chan struct{}) readDone := make(chan struct{})
@ -581,3 +634,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
return n, err return n, err
} }
} }
func (c *Client) handlePeersOnlineMsg(buf []byte) {
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
return
}
c.stateSubscription.OnPeersOnline(peersID)
}
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
peersID, err := messages.UnMarshalPeersWentOffline(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
return
}
c.stateSubscription.OnPeersWentOffline(peersID)
}

View File

@ -18,14 +18,19 @@ import (
) )
var ( var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{} hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234" serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234" serverURL = "rel://127.0.0.1:1234"
serverCfg = server.Config{
Meter: otel.Meter(""),
ExposedAddress: serverURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
_ = util.InitLog("error", "console") _ = util.InitLog("debug", "console")
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)
} }
@ -33,7 +38,7 @@ func TestMain(m *testing.M) {
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -58,37 +63,37 @@ func TestClient(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
t.Log("alice connecting to server") t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientAlice.Close() defer clientAlice.Close()
t.Log("placeholder connecting to server") t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect() err = clientPlaceHolder.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientPlaceHolder.Close() defer clientPlaceHolder.Close()
t.Log("Bob connecting to server") t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
defer clientBob.Close() defer clientBob.Close()
t.Log("Alice open connection to Bob") t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob") connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
t.Log("Bob open connection to Alice") t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice") connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -115,7 +120,7 @@ func TestClient(t *testing.T) {
func TestRegistration(t *testing.T) { func TestRegistration(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
_ = srv.Shutdown(ctx) _ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) {
_ = fakeTCPListener.Close() _ = fakeTCPListener.Close()
}(fakeTCPListener) }(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err == nil { if err == nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
@ -189,7 +194,7 @@ func TestEcho(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -213,8 +218,8 @@ func TestEcho(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@ -225,8 +230,8 @@ func TestEcho(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@ -237,12 +242,12 @@ func TestEcho(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err == nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("expected error when binding to unavailable peer, got nil")
} }
log.Infof("closing client") log.Infof("closing client")
@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") chBob, err := clientBob.OpenConn(ctx, "alice")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
chAlice, err := clientAlice.OpenConn("bob") chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
testString := "hello alice, I am bob" testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err == nil {
t.Errorf("expected error when writing to channel, got nil")
}
chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_, err = chBob.Write([]byte(testString)) _, err = chBob.Write([]byte(testString))
if err != nil { if err != nil {
t.Errorf("failed to write to channel: %s", err) t.Errorf("failed to write to channel: %s", err)
@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Errorf("failed to connect to server: %s", err) t.Errorf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") bob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientAlice.Connect() err = bob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
conn, err := clientAlice.OpenConn("bob") clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) {
t.Errorf("unexpected reading from closed connection") t.Errorf("unexpected reading from closed connection")
} }
_, err = clientAlice.OpenConn("bob") _, err = clientAlice.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) {
log.Fatalf("timeout waiting for client to disconnect") log.Fatalf("timeout waiting for client to disconnect")
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) {
idAlice := "alice" idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect() err = relayClient.Connect(ctx)
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
_, err = relayClient.OpenConn("bob") _, err = relayClient.OpenConn(ctx, "bob")
if err == nil { if err == nil {
t.Errorf("unexpected opening connection to closed server") t.Errorf("unexpected opening connection to closed server")
} }
@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
idAlice := "alice" idAlice := "alice"
idBob := "bob" idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr} srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect() err = clientAlice.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect() err = clientBob.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) {
} }
}() }()
connAliceToBob, err := clientAlice.OpenConn(idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@ -3,13 +3,14 @@ package client
import ( import (
"net" "net"
"time" "time"
"github.com/netbirdio/netbird/relay/messages"
) )
// Conn represent a connection to a relayed remote peer. // Conn represent a connection to a relayed remote peer.
type Conn struct { type Conn struct {
client *Client client *Client
dstID []byte dstID messages.PeerID
dstStringID string
messageChan chan Msg messageChan chan Msg
instanceURL *RelayAddr instanceURL *RelayAddr
} }
@ -17,14 +18,12 @@ type Conn struct {
// NewConn creates a new connection to a relayed remote peer. // NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer // client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID // dstID: the destination peer ID
// dstStringID: the destination peer ID in string format
// messageChan: the channel where the messages will be received // messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{ c := &Conn{
client: client, client: client,
dstID: dstID, dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan, messageChan: messageChan,
instanceURL: instanceURL, instanceURL: instanceURL,
} }
@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstStringID, c.dstID, p) return c.client.writeTo(c, c.dstID, p)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstStringID) return c.client.closeConn(c, c.dstID)
} }
func (c *Conn) LocalAddr() net.Addr { func (c *Conn) LocalAddr() net.Addr {

View File

@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(); err != nil { if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err) log.Errorf("failed to reconnect to relay server: %s", err)
return false return false
} }

View File

@ -42,7 +42,7 @@ type OnServerCloseListener func()
// ManagerService is the interface for the relay manager. // ManagerService is the interface for the relay manager.
type ManagerService interface { type ManagerService interface {
Serve() error Serve() error
OpenConn(serverAddress, peerKey string) (net.Conn, error) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error) RelayInstanceAddress() (string, error)
ServerURLs() []string ServerURLs() []string
@ -123,7 +123,7 @@ func (m *Manager) Serve() error {
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new // established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.Lock() m.relayClientMu.Lock()
defer m.relayClientMu.Unlock() defer m.relayClientMu.Unlock()
@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
) )
if !foreign { if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey) log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey) netConn, err = m.relayClient.OpenConn(ctx, peerKey)
} else { } else {
log.Debugf("open peer connection via foreign server: %s", serverAddress) log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey) netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token) return m.tokenStore.UpdateToken(token)
} }
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server // check if already has a connection to the desired relay server
m.relayClientsMutex.RLock() m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress] rt, ok := m.relayClients[serverAddress]
@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
m.relayClientsMutex.RUnlock() m.relayClientsMutex.RUnlock()
@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
if rt.err != nil { if rt.err != nil {
return nil, rt.err return nil, rt.err
} }
return rt.relayClient.OpenConn(peerKey) return rt.relayClient.OpenConn(ctx, peerKey)
} }
// create a new relay client and store it in the relayClients map // create a new relay client and store it in the relayClients map
@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
m.relayClients[serverAddress] = rt m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock() m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect() err := relayClient.Connect(m.ctx)
if err != nil { if err != nil {
rt.err = err rt.err = err
rt.Unlock() rt.Unlock()
@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
rt.relayClient = relayClient rt.relayClient = relayClient
rt.Unlock() rt.Unlock()
conn, err := relayClient.OpenConn(peerKey) conn, err := relayClient.OpenConn(ctx, peerKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
) )
@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) {
func TestForeignConn(t *testing.T) { func TestForeignConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ lstCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
srv1, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: lstCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) err := srv1.Listen(lstCfg1)
if err != nil { if err != nil {
errChan <- err errChan <- err
} }
@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: srvCfg2.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
err = clientAlice.Serve() if err := clientAlice.Serve(); err != nil {
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
idBob := "bob" clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
log.Debugf("connect by bob") if err := clientBob.Serve(); err != nil {
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
bobsSrvAddr, err := clientBob.RelayInstanceAddress() bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil { if err != nil {
t.Fatalf("failed to get relay address: %s", err) t.Fatalf("failed to get relay address: %s", err)
} }
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := mgrBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
err = mgr.Serve() err = mgr.Serve()
if err != nil { if err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -212,7 +222,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg1 := server.ListenerConfig{ srvCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv1, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -241,7 +251,7 @@ func TestForeginAutoClose(t *testing.T) {
srvCfg2 := server.ListenerConfig{ srvCfg2 := server.ListenerConfig{
Address: "localhost:2234", Address: "localhost:2234",
} }
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) srv2, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -277,7 +287,7 @@ func TestForeginAutoClose(t *testing.T) {
} }
t.Log("open connection to another peer") t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
@ -305,7 +315,7 @@ func TestAutoReconnect(t *testing.T) {
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -330,6 +340,13 @@ func TestAutoReconnect(t *testing.T) {
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve() err = clientAlice.Serve()
if err != nil { if err != nil {
@ -339,7 +356,7 @@ func TestAutoReconnect(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("failed to get relay address: %s", err) t.Errorf("failed to get relay address: %s", err)
} }
conn, err := clientAlice.OpenConn(ra, "bob") conn, err := clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to bind channel: %s", err) t.Errorf("failed to bind channel: %s", err)
} }
@ -357,7 +374,7 @@ func TestAutoReconnect(t *testing.T) {
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob") _, err = clientAlice.OpenConn(ctx, ra, "bob")
if err != nil { if err != nil {
t.Errorf("failed to open channel: %s", err) t.Errorf("failed to open channel: %s", err)
} }
@ -366,24 +383,27 @@ func TestAutoReconnect(t *testing.T) {
func TestNotifierDoubleAdd(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srvCfg1 := server.ListenerConfig{ listenerCfg1 := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
} }
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) srv, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: listenerCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
err := srv1.Listen(srvCfg1) if err := srv.Listen(listenerCfg1); err != nil {
if err != nil {
errChan <- err errChan <- err
} }
}() }()
defer func() { defer func() {
err := srv1.Shutdown(ctx) if err := srv.Shutdown(ctx); err != nil {
if err != nil {
t.Errorf("failed to close server: %s", err) t.Errorf("failed to close server: %s", err)
} }
}() }()
@ -392,17 +412,21 @@ func TestNotifierDoubleAdd(t *testing.T) {
t.Fatalf("failed to start server: %s", err) t.Fatalf("failed to start server: %s", err)
} }
idAlice := "alice"
log.Debugf("connect by alice") log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx) mCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = clientAlice.Serve() clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
if err != nil { if err = clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err) t.Fatalf("failed to serve manager: %s", err)
} }
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
if err = clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@ -0,0 +1,168 @@
package client
import (
"context"
"errors"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
const (
	// OpenConnectionTimeout is the maximum time WaitToBeOnlineAndSubscribe
	// waits for a remote peer to come online before giving up.
	OpenConnectionTimeout = 30 * time.Second
)
// relayedConnWriter is the minimal write-only view of the relay connection
// used to send subscribe/unsubscribe control messages to the server.
type relayedConnWriter interface {
	Write(p []byte) (n int, err error)
}
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
// over a relay connection. It allows tracking peers' availability and handling offline
// events via a callback. We get online notification from the server only once.
//
// NOTE(review): the maps are accessed without locking; this assumes all calls
// are serialized by the owning client (e.g. its receive loop) — confirm before
// invoking these methods from multiple goroutines.
type PeersStateSubscription struct {
	log *log.Entry
	// relayConn carries the marshaled subscribe/unsubscribe messages to the server.
	relayConn relayedConnWriter
	// offlineCallback is invoked with the subset of subscribed peers that went offline.
	offlineCallback func(peerIDs []messages.PeerID)
	// listenForOfflinePeers tracks peers whose offline events should be reported.
	listenForOfflinePeers map[messages.PeerID]struct{}
	// waitingPeers maps a peer ID to a channel that is closed once that peer comes online.
	waitingPeers map[messages.PeerID]chan struct{}
}
// NewPeersStateSubscription creates a subscription tracker that sends its
// control messages over relayConn and reports offline peers via offlineCallback.
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
	s := &PeersStateSubscription{
		log:                   log,
		relayConn:             relayConn,
		offlineCallback:       offlineCallback,
		listenForOfflinePeers: map[messages.PeerID]struct{}{},
		waitingPeers:          map[messages.PeerID]chan struct{}{},
	}
	return s
}
// OnPeersOnline should be called when a notification is received that certain
// peers have come online. Any waiter registered for one of these peers is
// released by closing its wait channel; the entry is then removed so the
// channel cannot be closed twice.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
	for _, id := range peersID {
		if ch, waiting := s.waitingPeers[id]; waiting {
			close(ch)
			delete(s.waitingPeers, id)
		}
	}
}
// OnPeersWentOffline narrows the reported peers down to the ones this
// subscription is listening for and, if any remain, hands them to the
// offline callback.
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
	var subscribed []messages.PeerID
	for _, id := range peersID {
		if _, ok := s.listenForOfflinePeers[id]; ok {
			subscribed = append(subscribed, id)
		}
	}

	if len(subscribed) == 0 {
		return
	}
	s.offlineCallback(subscribed)
}
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
// It registers a wait channel for peerID, sends the subscribe message, and then
// blocks until the peer is reported online (OnPeersOnline closes the channel),
// the caller's context is cancelled, or OpenConnectionTimeout elapses.
// todo: when we unsubscribe while this is running, this will not return with error
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
	// Check if already waiting for this peer; only one concurrent wait per peer is allowed.
	if _, exists := s.waitingPeers[peerID]; exists {
		return errors.New("already waiting for peer to come online")
	}

	// Create a channel to wait for the peer to come online.
	// OnPeersOnline closes and removes it on success; Cleanup/UnsubscribeStateChange
	// may also close it while we block below.
	waitCh := make(chan struct{})
	s.waitingPeers[peerID] = waitCh

	if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil {
		s.log.Errorf("failed to subscribe to peer state: %s", err)
		// Roll back the registration made above; the deferred cleanup is not yet installed.
		close(waitCh)
		delete(s.waitingPeers, peerID)
		return err
	}

	// Deferred cleanup: close/remove the channel only if it is still the one we
	// registered (ch == waitCh). If OnPeersOnline already closed and deleted it,
	// or another channel replaced it, skip to avoid a double close.
	defer func() {
		if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
			close(waitCh)
			delete(s.waitingPeers, peerID)
		}
	}()

	// Wait for peer to come online or context to be cancelled
	timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
	defer cancel()

	select {
	case <-waitCh:
		s.log.Debugf("peer %s is now online", peerID)
		return nil
	case <-timeoutCtx.Done():
		// Timed out or parent context cancelled: best-effort unsubscribe on the server,
		// then report the context error to the caller.
		s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
		if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
			s.log.Errorf("failed to unsubscribe from peer state: %s", err)
		}
		return timeoutCtx.Err()
	}
}
// UnsubscribeStateChange sends an unsubscribe request for the given peers and
// releases any goroutine still waiting for one of them to come online.
// Local cleanup always runs; the write error (if any) is returned afterwards.
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
	msgErr := s.unsubscribeStateChange(peerIDs)

	for _, id := range peerIDs {
		wch, waiting := s.waitingPeers[id]
		if waiting {
			close(wch)
			delete(s.waitingPeers, id)
		}
		delete(s.listenForOfflinePeers, id)
	}
	return msgErr
}
// Cleanup releases every waiting goroutine and drops all subscription state.
func (s *PeersStateSubscription) Cleanup() {
	for _, ch := range s.waitingPeers {
		close(ch)
	}

	s.waitingPeers = map[messages.PeerID]chan struct{}{}
	s.listenForOfflinePeers = map[messages.PeerID]struct{}{}
}
// subscribeStateChange marshals a subscribe request for peerIDs, records them
// as offline-listened peers and writes the resulting message(s) to the relay.
func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error {
	msgs, err := messages.MarshalSubPeerStateMsg(peerIDs)
	if err != nil {
		return err
	}

	// Registration happens before writing; it is kept even if a write below fails.
	for _, id := range peerIDs {
		s.listenForOfflinePeers[id] = struct{}{}
	}

	for _, m := range msgs {
		if _, werr := s.relayConn.Write(m); werr != nil {
			return werr
		}
	}
	return nil
}
// unsubscribeStateChange marshals an unsubscribe request for peerIDs and writes
// every resulting message, returning the last write error (if any) so all
// messages get a send attempt.
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
	msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
	if err != nil {
		return err
	}

	var lastWriteErr error
	for _, m := range msgs {
		if _, werr := s.relayConn.Write(m); werr != nil {
			lastWriteErr = werr
		}
	}
	return lastWriteErr
}

View File

@ -0,0 +1,99 @@
package client
import (
"bytes"
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/messages"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockRelayedConn is a no-op relayedConnWriter used by the tests below.
type mockRelayedConn struct {
}

// Write pretends the whole buffer was written successfully.
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
	return len(p), nil
}
// TestWaitToBeOnlineAndSubscribe_Success verifies the wait returns nil once the
// peer is reported online.
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
	peerID := messages.HashID("peer1")
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{}) // discard log output
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), &mockRelayedConn{}, nil)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Report the peer online shortly after the wait starts.
	go func() {
		time.Sleep(100 * time.Millisecond)
		sub.OnPeersOnline([]messages.PeerID{peerID})
	}()

	assert.NoError(t, sub.WaitToBeOnlineAndSubscribe(ctx, peerID))
}
// TestWaitToBeOnlineAndSubscribe_Timeout verifies the wait fails with
// context.DeadlineExceeded when the peer never comes online.
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
	peerID := messages.HashID("peer2")
	mockConn := &mockRelayedConn{}
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{})
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
	assert.Error(t, err)
	// errors.Is-based check stays correct if the error is ever wrapped,
	// unlike a strict assert.Equal on the error value.
	assert.ErrorIs(t, err, context.DeadlineExceeded)
}
// TestWaitToBeOnlineAndSubscribe_Duplicate verifies a second concurrent wait
// for the same peer is rejected.
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
	peerID := messages.HashID("peer3")
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{})
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), &mockRelayedConn{}, nil)

	ctx := context.Background()
	// First waiter runs in the background; give it time to register.
	go func() {
		_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
	}()
	time.Sleep(100 * time.Millisecond)

	err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "already waiting")
}
// TestUnsubscribeStateChange verifies that unsubscribing a peer releases a
// goroutine blocked in WaitToBeOnlineAndSubscribe for that peer.
func TestUnsubscribeStateChange(t *testing.T) {
	peerID := messages.HashID("peer4")
	mockConn := &mockRelayedConn{}
	logger := logrus.New()
	logger.SetOutput(&bytes.Buffer{})
	sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)

	doneChan := make(chan struct{})
	go func() {
		_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
		close(doneChan)
	}()

	// Give the waiter time to register itself before unsubscribing.
	time.Sleep(100 * time.Millisecond)
	err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
	assert.NoError(t, err)

	select {
	case <-doneChan:
		// Unsubscribe released the waiter as expected.
	case <-time.After(200 * time.Millisecond):
		// The waiter is still blocked: unsubscribe failed to release it.
		t.Errorf("timeout")
	}
}

View File

@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url) log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect() err := relayClient.Connect(ctx)
resultChan <- connResult{ resultChan <- connResult{
RelayClient: relayClient, RelayClient: relayClient,
Url: url, Url: url,

View File

@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error {
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) cfg := server.Config{
Meter: metricsServer.Meter,
ExposedAddress: cobraConfig.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
srv, err := server.NewServer(cfg)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err)

View File

@ -8,24 +8,24 @@ import (
const ( const (
prefixLength = 4 prefixLength = 4
IDSize = prefixLength + sha256.Size peerIDSize = prefixLength + sha256.Size
) )
var ( var (
prefix = []byte("sha-") // 4 bytes prefix = []byte("sha-") // 4 bytes
) )
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string type PeerID [peerIDSize]byte
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID)) func (p PeerID) String() string {
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
var prefixedHash []byte
prefixedHash = append(prefixedHash, prefix...)
prefixedHash = append(prefixedHash, idHash[:]...)
return prefixedHash, idHashString
} }
// HashIDToString converts a hash to a human-readable string // HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
func HashIDToString(idHash []byte) string { func HashID(peerID string) PeerID {
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) idHash := sha256.Sum256([]byte(peerID))
var prefixedHash [peerIDSize]byte
copy(prefixedHash[:prefixLength], prefix)
copy(prefixedHash[prefixLength:], idHash[:])
return prefixedHash
} }

View File

@ -1,13 +0,0 @@
package messages
import (
"testing"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("alice")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
}

View File

@ -9,19 +9,26 @@ import (
const ( const (
MaxHandshakeSize = 212 MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192 MaxHandshakeRespSize = 8192
MaxMessageSize = 8820
CurrentProtocolVersion = 1 CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0 MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead. // Deprecated: Use MsgTypeAuth instead.
MsgTypeHello MsgType = 1 MsgTypeHello = 1
// Deprecated: Use MsgTypeAuthResponse instead. // Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse MsgType = 2 MsgTypeHelloResponse = 2
MsgTypeTransport MsgType = 3 MsgTypeTransport = 3
MsgTypeClose MsgType = 4 MsgTypeClose = 4
MsgTypeHealthCheck MsgType = 5 MsgTypeHealthCheck = 5
MsgTypeAuth = 6 MsgTypeAuth = 6
MsgTypeAuthResponse = 7 MsgTypeAuthResponse = 7
// Peers state messages
MsgTypeSubscribePeerState = 8
MsgTypeUnsubscribePeerState = 9
MsgTypePeersOnline = 10
MsgTypePeersWentOffline = 11
// base size of the message // base size of the message
sizeOfVersionByte = 1 sizeOfVersionByte = 1
@ -30,17 +37,17 @@ const (
// auth message // auth message
sizeOfMagicByte = 4 sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize headerSizeAuth = sizeOfMagicByte + peerIDSize
offsetMagicByte = sizeOfProtoHeader offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message // hello message
headerSizeHello = sizeOfMagicByte + IDSize headerSizeHello = sizeOfMagicByte + peerIDSize
headerSizeHelloResp = 0 headerSizeHelloResp = 0
// transport // transport
headerSizeTransport = IDSize headerSizeTransport = peerIDSize
offsetTransportID = sizeOfProtoHeader offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
) )
@ -72,6 +79,14 @@ func (m MsgType) String() string {
return "close" return "close"
case MsgTypeHealthCheck: case MsgTypeHealthCheck:
return "health check" return "health check"
case MsgTypeSubscribePeerState:
return "subscribe peer state"
case MsgTypeUnsubscribePeerState:
return "unsubscribe peer state"
case MsgTypePeersOnline:
return "peers online"
case MsgTypePeersWentOffline:
return "peers went offline"
default: default:
return "unknown" return "unknown"
} }
@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
MsgTypeAuth, MsgTypeAuth,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypeSubscribePeerState,
MsgTypeUnsubscribePeerState:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
MsgTypeAuthResponse, MsgTypeAuthResponse,
MsgTypeTransport, MsgTypeTransport,
MsgTypeClose, MsgTypeClose,
MsgTypeHealthCheck: MsgTypeHealthCheck,
MsgTypePeersOnline,
MsgTypePeersWentOffline:
return msgType, nil return msgType, nil
default: default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID...) msg = append(msg, peerID[:]...)
msg = append(msg, additions...) msg = append(msg, additions...)
return msg, nil return msg, nil
@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthMsg instead. // Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server. // authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < sizeOfProtoHeader+headerSizeHello { if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
return &peerID, msg[headerSizeHello:], nil
} }
// Deprecated: Use MarshalAuthResponse instead. // Deprecated: Use MarshalAuthResponse instead.
@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
// message is used to authenticate the client with the server. The authentication is done using an HMAC method. // message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response. // close the network connection without any response.
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
if len(peerID) != IDSize { if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) return nil, fmt.Errorf("too large auth payload")
} }
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) msg := make([]byte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth) msg[1] = byte(MsgTypeAuth)
copy(msg[sizeOfProtoHeader:], magicHeader) copy(msg[sizeOfProtoHeader:], magicHeader)
copy(msg[offsetAuthPeerID:], peerID[:])
msg = append(msg, peerID...) copy(msg[headerTotalSizeAuth:], authPayload)
msg = append(msg, authPayload...)
return msg, nil return msg, nil
} }
// UnmarshalAuthMsg extracts peerID and the auth payload from the message // UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < headerTotalSizeAuth { if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
// Validate the magic header
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header") return nil, nil, errors.New("invalid magic header")
} }
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
return &peerID, msg[headerTotalSizeAuth:], nil
} }
// MarshalAuthResponse creates a response message to the auth. // MarshalAuthResponse creates a response message to the auth.
@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte {
// MarshalTransportMsg creates a transport message. // MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID. // destination peer hashed ID.
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize { // todo validate size
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) msg := make([]byte, headerTotalSizeTransport+len(payload))
}
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion) msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport) msg[1] = byte(MsgTypeTransport)
copy(msg[sizeOfProtoHeader:], peerID) copy(msg[sizeOfProtoHeader:], peerID[:])
msg = append(msg, payload...) copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
return msg, nil return msg, nil
} }
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message. // UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength return nil, nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil const offsetEnd = offsetTransportID + peerIDSize
var peerID PeerID
copy(peerID[:], buf[offsetTransportID:offsetEnd])
return &peerID, buf[headerTotalSizeTransport:], nil
} }
// UnmarshalTransportID extracts the peerID from the transport message. // UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) { func UnmarshalTransportID(buf []byte) (*PeerID, error) {
if len(buf) < headerTotalSizeTransport { if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength return nil, ErrInvalidMessageLength
} }
return buf[offsetTransportID:headerTotalSizeTransport], nil
const offsetEnd = offsetTransportID + peerIDSize
var id PeerID
copy(id[:], buf[offsetTransportID:offsetEnd])
return &id, nil
} }
// UpdateTransportMsg updates the peerID in the transport message. // UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice. // need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error { func UpdateTransportMsg(msg []byte, peerID PeerID) error {
if len(msg) < offsetTransportID+len(peerID) { if len(msg) < offsetTransportID+peerIDSize {
return ErrInvalidMessageLength return ErrInvalidMessageLength
} }
copy(msg[offsetTransportID:], peerID) copy(msg[offsetTransportID:], peerID[:])
return nil return nil
} }

View File

@ -5,7 +5,7 @@ import (
) )
func TestMarshalHelloMsg(t *testing.T) { func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalHelloMsg(peerID, nil) msg, err := MarshalHelloMsg(peerID, nil)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
func TestMarshalAuthMsg(t *testing.T) { func TestMarshalAuthMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalAuthMsg(peerID, []byte{}) msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(receivedPeerID) != string(peerID) { if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID) t.Errorf("expected %s, got %s", peerID, receivedPeerID)
} }
} }
@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) {
} }
func TestMarshalTransportMsg(t *testing.T) { func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload") payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload) msg, err := MarshalTransportMsg(peerID, payload)
if err != nil { if err != nil {
@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("failed to unmarshal transport id: %v", err) t.Fatalf("failed to unmarshal transport id: %v", err)
} }
if string(uPeerID) != string(peerID) { if uPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, uPeerID) t.Errorf("expected %s, got %s", peerID, uPeerID)
} }
@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
if string(id) != string(peerID) { if id.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, id) t.Errorf("expected: '%s', got: '%s'", peerID, id)
} }
if string(respPayload) != string(payload) { if string(respPayload) != string(payload) {

View File

@ -0,0 +1,92 @@
package messages
import (
"fmt"
)
// MarshalSubPeerStateMsg encodes a subscribe-to-peer-state request, split into
// as many messages as needed to fit MaxMessageSize.
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
}

// UnmarshalSubPeerStateMsg decodes the peer IDs from a subscribe request.
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// MarshalUnsubPeerStateMsg encodes an unsubscribe-from-peer-state request.
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
}

// UnmarshalUnsubPeerStateMsg decodes the peer IDs from an unsubscribe request.
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// MarshalPeersOnline encodes a peers-came-online notification.
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
}

// UnmarshalPeersOnlineMsg decodes the peer IDs from a peers-online notification.
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// MarshalPeersWentOffline encodes a peers-went-offline notification.
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
	return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
}

// UnmarshalPeersWentOffline decodes the peer IDs from a peers-went-offline
// notification. It follows the UnmarshalXxx naming used by the other
// decoders in this file.
func UnmarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}

// UnMarshalPeersWentOffline decodes the peer IDs from a peers-went-offline
// notification.
//
// Deprecated: use UnmarshalPeersWentOffline instead; this misspelled name is
// kept only for backward compatibility.
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
	return unmarshalPeerIDs(buf)
}
// marshalPeerIDs encodes ids into one or more wire messages of the given
// msgType, chunking the list so no single message exceeds MaxMessageSize.
// An empty list is rejected with an error.
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
	if len(ids) == 0 {
		return nil, fmt.Errorf("no list of peer ids provided")
	}

	const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize

	var out [][]byte
	remaining := ids
	for len(remaining) > 0 {
		n := len(remaining)
		if n > maxPeersPerMessage {
			n = maxPeersPerMessage
		}
		chunk := remaining[:n]
		remaining = remaining[n:]

		buf := make([]byte, sizeOfProtoHeader+n*peerIDSize)
		buf[0] = byte(CurrentProtocolVersion)
		buf[1] = msgType

		pos := sizeOfProtoHeader
		for _, id := range chunk {
			pos += copy(buf[pos:], id[:])
		}

		out = append(out, buf)
	}
	return out, nil
}
// unmarshalPeerIDs decodes the peer IDs that follow the protocol header.
// The payload length must be an exact multiple of peerIDSize.
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
	if len(buf) < sizeOfProtoHeader {
		return nil, fmt.Errorf("invalid message format")
	}

	payload := buf[sizeOfProtoHeader:]
	if len(payload)%peerIDSize != 0 {
		return nil, fmt.Errorf("invalid peer list size: %d", len(payload))
	}

	ids := make([]PeerID, len(payload)/peerIDSize)
	for i := range ids {
		copy(ids[i][:], payload[i*peerIDSize:(i+1)*peerIDSize])
	}
	return ids, nil
}

View File

@ -0,0 +1,144 @@
package messages
import (
"bytes"
"testing"
)
const (
	// testPeerCount is the number of synthetic peer IDs generated per test.
	testPeerCount = 10
)
// generateTestPeerIDs returns n deterministic PeerIDs; byte j of ID i is i+j
// (mod 256), so every ID is distinct and reproducible.
func generateTestPeerIDs(n int) []PeerID {
	ids := make([]PeerID, n)
	for i := range ids {
		for j := range ids[i] {
			ids[i][j] = byte(i + j)
		}
	}
	return ids
}
// peerIDEqual reports whether a and b hold the same PeerIDs in the same order.
func peerIDEqual(a, b []PeerID) bool {
	if len(a) != len(b) {
		return false
	}
	for i, id := range a {
		if !bytes.Equal(id[:], b[i][:]) {
			return false
		}
	}
	return true
}
// TestMarshalUnmarshalSubPeerState round-trips subscribe messages through
// marshal and unmarshal, concatenating all chunks back together.
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
	want := generateTestPeerIDs(testPeerCount)

	msgs, err := MarshalSubPeerStateMsg(want)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var got []PeerID
	for _, m := range msgs {
		decoded, uerr := UnmarshalSubPeerStateMsg(m)
		if uerr != nil {
			t.Fatalf("unmarshal failed: %v", uerr)
		}
		got = append(got, decoded...)
	}

	if !peerIDEqual(want, got) {
		t.Errorf("expected %v, got %v", want, got)
	}
}
// TestMarshalSubPeerState_EmptyInput ensures marshalling rejects an empty list.
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
	if _, err := MarshalSubPeerStateMsg([]PeerID{}); err == nil {
		t.Errorf("expected error for empty input")
	}
}
// TestUnmarshalSubPeerState_Invalid covers a buffer shorter than the header
// and a payload that is not a multiple of the peer ID size.
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
	// Too short
	if _, err := UnmarshalSubPeerStateMsg([]byte{1}); err == nil {
		t.Errorf("expected error for short input")
	}

	// Misaligned length
	if _, err := UnmarshalSubPeerStateMsg(make([]byte, sizeOfProtoHeader+1)); err == nil {
		t.Errorf("expected error for misaligned input")
	}
}
// TestMarshalUnmarshalPeersOnline round-trips peers-online messages through
// marshal and unmarshal, concatenating all chunks back together.
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
	want := generateTestPeerIDs(testPeerCount)

	msgs, err := MarshalPeersOnline(want)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var got []PeerID
	for _, m := range msgs {
		decoded, uerr := UnmarshalPeersOnlineMsg(m)
		if uerr != nil {
			t.Fatalf("unmarshal failed: %v", uerr)
		}
		got = append(got, decoded...)
	}

	if !peerIDEqual(want, got) {
		t.Errorf("expected %v, got %v", want, got)
	}
}
// TestMarshalPeersOnline_EmptyInput ensures marshalling rejects an empty list.
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
	if _, err := MarshalPeersOnline([]PeerID{}); err == nil {
		t.Errorf("expected error for empty input")
	}
}
// TestUnmarshalPeersOnline_Invalid ensures a buffer shorter than the protocol
// header is rejected.
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
	if _, err := UnmarshalPeersOnlineMsg([]byte{1}); err == nil {
		t.Errorf("expected error for short input")
	}
}
// TestMarshalUnmarshalPeersWentOffline round-trips went-offline messages
// through marshal and unmarshal, concatenating all chunks back together.
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
	ids := generateTestPeerIDs(testPeerCount)
	msgs, err := MarshalPeersWentOffline(ids)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	var allIDs []PeerID
	for _, msg := range msgs {
		// Use the dedicated went-offline decoder; the previous version reused
		// UnmarshalPeersOnlineMsg with a comment wrongly claiming no matching
		// unmarshal function exists.
		decoded, err := UnMarshalPeersWentOffline(msg)
		if err != nil {
			t.Fatalf("unmarshal failed: %v", err)
		}
		allIDs = append(allIDs, decoded...)
	}
	if !peerIDEqual(ids, allIDs) {
		t.Errorf("expected %v, got %v", ids, allIDs)
	}
}
// TestMarshalPeersWentOffline_EmptyInput ensures marshalling rejects an empty list.
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
	if _, err := MarshalPeersWentOffline([]PeerID{}); err == nil {
		t.Errorf("expected error for empty input")
	}
}

View File

@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address" "github.com/netbirdio/netbird/relay/messages/address"
@ -14,6 +13,12 @@ import (
authmsg "github.com/netbirdio/netbird/relay/messages/auth" authmsg "github.com/netbirdio/netbird/relay/messages/auth"
) )
type Validator interface {
Validate(any) error
// Deprecated: Use Validate instead.
ValidateHelloMsgType(any) error
}
// preparedMsg contains the marshalled success response messages // preparedMsg contains the marshalled success response messages
type preparedMsg struct { type preparedMsg struct {
responseHelloMsg []byte responseHelloMsg []byte
@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
type handshake struct { type handshake struct {
conn net.Conn conn net.Conn
validator auth.Validator validator Validator
preparedMsg *preparedMsg preparedMsg *preparedMsg
handshakeMethodAuth bool handshakeMethodAuth bool
peerID string peerID *messages.PeerID
} }
func (h *handshake) handshakeReceive() ([]byte, error) { func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
buf := make([]byte, messages.MaxHandshakeSize) buf := make([]byte, messages.MaxHandshakeSize)
n, err := h.conn.Read(buf) n, err := h.conn.Read(buf)
if err != nil { if err != nil {
@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
} }
var ( var peerID *messages.PeerID
bytePeerID []byte
peerID string
)
switch msgType { switch msgType {
//nolint:staticcheck //nolint:staticcheck
case messages.MsgTypeHello: case messages.MsgTypeHello:
bytePeerID, peerID, err = h.handleHelloMsg(buf) peerID, err = h.handleHelloMsg(buf)
case messages.MsgTypeAuth: case messages.MsgTypeAuth:
h.handshakeMethodAuth = true h.handshakeMethodAuth = true
bytePeerID, peerID, err = h.handleAuthMsg(buf) peerID, err = h.handleAuthMsg(buf)
default: default:
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
} }
@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
return nil, err return nil, err
} }
h.peerID = peerID h.peerID = peerID
return bytePeerID, nil return peerID, nil
} }
func (h *handshake) handshakeResponse() error { func (h *handshake) handshakeResponse() error {
@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
return nil return nil
} }
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
//nolint:staticcheck //nolint:staticcheck
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) peerID, authData, err := messages.UnmarshalHelloMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
authMsg, err := authmsg.UnmarshalMsg(authData) authMsg, err := authmsg.UnmarshalMsg(authData)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal auth message: %w", err) return nil, fmt.Errorf("unmarshal auth message: %w", err)
} }
//nolint:staticcheck //nolint:staticcheck
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return peerID, nil
} }
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("unmarshal hello message: %w", err) return nil, fmt.Errorf("unmarshal hello message: %w", err)
} }
peerID := messages.HashIDToString(rawPeerID)
if err := h.validator.Validate(authPayload); err != nil { if err := h.validator.Validate(authPayload); err != nil {
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
} }
return rawPeerID, peerID, nil return rawPeerID, nil
} }

View File

@ -12,43 +12,50 @@ import (
"github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
const ( const (
bufferSize = 8820 bufferSize = messages.MaxMessageSize
errCloseConn = "failed to close connection to peer: %s" errCloseConn = "failed to close connection to peer: %s"
) )
// Peer represents a peer connection // Peer represents a peer connection
type Peer struct { type Peer struct {
metrics *metrics.Metrics metrics *metrics.Metrics
log *log.Entry log *log.Entry
idS string id messages.PeerID
idB []byte conn net.Conn
conn net.Conn connMu sync.RWMutex
connMu sync.RWMutex store *store.Store
store *Store notifier *store.PeerNotifier
peersListener *store.Listener
} }
// NewPeer creates a new Peer instance and prepare custom logging // NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
stringID := messages.HashIDToString(id) p := &Peer{
return &Peer{ metrics: metrics,
metrics: metrics, log: log.WithField("peer_id", id.String()),
log: log.WithField("peer_id", stringID), id: id,
idS: stringID, conn: conn,
idB: id, store: store,
conn: conn, notifier: notifier,
store: store,
} }
return p
} }
// Work reads data from the connection // Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
defer func() { defer func() {
p.notifier.RemoveListener(p.peersListener)
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err) p.log.Errorf(errCloseConn, err)
} }
@ -94,6 +101,10 @@ func (p *Peer) Work() {
} }
} }
func (p *Peer) ID() messages.PeerID {
return p.id
}
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
switch msgType { switch msgType {
case messages.MsgTypeHealthCheck: case messages.MsgTypeHealthCheck:
@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf(errCloseConn, err) log.Errorf(errCloseConn, err)
} }
case messages.MsgTypeSubscribePeerState:
p.handleSubscribePeerState(msg)
case messages.MsgTypeUnsubscribePeerState:
p.handleUnsubscribePeerState(msg)
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
} }
@ -145,7 +160,7 @@ func (p *Peer) Close() {
// String returns the peer ID // String returns the peer ID
func (p *Peer) String() string { func (p *Peer) String() string {
return p.idS return p.id.String()
} }
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
return return
} }
stringPeerID := messages.HashIDToString(peerID) item, ok := p.store.Peer(*peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok { if !ok {
p.log.Debugf("peer not found: %s", stringPeerID) p.log.Debugf("peer not found: %s", peerID)
return return
} }
dp := item.(*Peer)
err = messages.UpdateTransportMsg(msg, p.idB) err = messages.UpdateTransportMsg(msg, p.id)
if err != nil { if err != nil {
p.log.Errorf("failed to update transport message: %s", err) p.log.Errorf("failed to update transport message: %s", err)
return return
@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) {
} }
p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
} }
func (p *Peer) handleSubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
if err != nil {
p.log.Errorf("failed to unmarshal open connection message: %s", err)
return
}
p.peersListener.RemoveInterestedPeer(peerIDs)
}
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersOnline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)
return
}
for n, msg := range msgs {
if _, err := p.Write(msg); err != nil {
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
}
}
}

View File

@ -4,26 +4,55 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/url"
"strings"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/relay/auth"
//nolint:staticcheck //nolint:staticcheck
"github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
) )
type Config struct {
Meter metric.Meter
ExposedAddress string
TLSSupport bool
AuthValidator Validator
instanceURL string
}
func (c *Config) validate() error {
if c.Meter == nil {
c.Meter = otel.Meter("")
}
if c.ExposedAddress == "" {
return fmt.Errorf("exposed address is required")
}
instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport)
if err != nil {
return fmt.Errorf("invalid url: %v", err)
}
c.instanceURL = instanceURL
if c.AuthValidator == nil {
return fmt.Errorf("auth validator is required")
}
return nil
}
// Relay represents the relay server // Relay represents the relay server
type Relay struct { type Relay struct {
metrics *metrics.Metrics metrics *metrics.Metrics
metricsCancel context.CancelFunc metricsCancel context.CancelFunc
validator auth.Validator validator Validator
store *Store store *store.Store
notifier *store.PeerNotifier
instanceURL string instanceURL string
preparedMsg *preparedMsg preparedMsg *preparedMsg
@ -31,40 +60,40 @@ type Relay struct {
closeMu sync.RWMutex closeMu sync.RWMutex
} }
// NewRelay creates a new Relay instance // NewRelay creates and returns a new Relay instance.
// //
// Parameters: // Parameters:
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage //
// metrics for the relay server. // config: A Config struct that holds the configuration needed to initialize the relay server.
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this // - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used.
// address as the relay server's instance URL. // - ExposedAddress: The external address clients use to reach this relay. Required.
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The // - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL.
// instance URL depends on this value. // - AuthValidator: A Validator implementation used to authenticate peers. Required.
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
// peers.
// //
// Returns: // Returns:
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil. //
// Otherwise, the error contains the details of what went wrong. // A pointer to a Relay instance and an error. If initialization is successful, the error will be nil;
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) { // otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration).
func NewRelay(config Config) (*Relay, error) {
if err := config.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %v", err)
}
ctx, metricsCancel := context.WithCancel(context.Background()) ctx, metricsCancel := context.WithCancel(context.Background())
m, err := metrics.NewMetrics(ctx, meter) m, err := metrics.NewMetrics(ctx, config.Meter)
if err != nil { if err != nil {
metricsCancel() metricsCancel()
return nil, fmt.Errorf("creating app metrics: %v", err) return nil, fmt.Errorf("creating app metrics: %v", err)
} }
peerStore := store.NewStore()
r := &Relay{ r := &Relay{
metrics: m, metrics: m,
metricsCancel: metricsCancel, metricsCancel: metricsCancel,
validator: validator, validator: config.AuthValidator,
store: NewStore(), instanceURL: config.instanceURL,
} store: peerStore,
notifier: store.NewPeerNotifier(peerStore),
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
if err != nil {
metricsCancel()
return nil, fmt.Errorf("get instance URL: %v", err)
} }
r.preparedMsg, err = newPreparedMsg(r.instanceURL) r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@ -76,32 +105,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
return r, nil return r, nil
} }
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
// provided address according to TLS definition and parses the address before returning it
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
addr := exposedAddress
split := strings.Split(exposedAddress, "://")
switch {
case len(split) == 1 && tlsSupported:
addr = "rels://" + exposedAddress
case len(split) == 1 && !tlsSupported:
addr = "rel://" + exposedAddress
case len(split) > 2:
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
}
parsedURL, err := url.ParseRequestURI(addr)
if err != nil {
return "", fmt.Errorf("invalid exposed address: %v", err)
}
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
}
return parsedURL.String(), nil
}
// Accept start to handle a new peer connection // Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) { func (r *Relay) Accept(conn net.Conn) {
acceptTime := time.Now() acceptTime := time.Now()
@ -125,14 +128,17 @@ func (r *Relay) Accept(conn net.Conn) {
return return
} }
peer := NewPeer(r.metrics, peerID, conn, r.store) peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now() storeTime := time.Now()
r.store.AddPeer(peer) r.store.AddPeer(peer)
r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String()) r.metrics.PeerConnected(peer.String())
go func() { go func() {
peer.Work() peer.Work()
r.notifier.PeerWentOffline(peer.ID())
r.store.DeletePeer(peer) r.store.DeletePeer(peer)
peer.log.Debugf("relay connection closed") peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String()) r.metrics.PeerDisconnected(peer.String())
@ -154,12 +160,12 @@ func (r *Relay) Shutdown(ctx context.Context) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
peers := r.store.Peers() peers := r.store.Peers()
for _, peer := range peers { for _, v := range peers {
wg.Add(1) wg.Add(1)
go func(p *Peer) { go func(p *Peer) {
p.CloseGracefully(ctx) p.CloseGracefully(ctx)
wg.Done() wg.Done()
}(peer) }(v.(*Peer))
} }
wg.Wait() wg.Wait()
r.metricsCancel() r.metricsCancel()

View File

@ -6,15 +6,12 @@ import (
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls" quictls "github.com/netbirdio/netbird/relay/tls"
log "github.com/sirupsen/logrus"
) )
// ListenerConfig is the configuration for the listener. // ListenerConfig is the configuration for the listener.
@ -33,13 +30,22 @@ type Server struct {
listeners []listener.Listener listeners []listener.Listener
} }
// NewServer creates a new relay server instance. // NewServer creates and returns a new relay server instance.
// meter: the OpenTelemetry meter //
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format. // Parameters:
// tlsSupport: if true, the server will support TLS //
// authValidator: the auth validator to use for the server // config: A Config struct containing the necessary configuration:
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) { // - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator) // - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
// - AuthValidator: A Validator used to authenticate peers. Required.
//
// Returns:
//
// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds,
// the returned error will be nil. Otherwise, the error will describe the problem.
func NewServer(config Config) (*Server, error) {
relay, err := NewRelay(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -0,0 +1,121 @@
package store
import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
// Listener delivers peer online/offline notifications to a single subscriber.
// It keeps two interest sets (one per event direction) and two buffered event
// queues that are drained in batches by the listenForEvents goroutine.
type Listener struct {
	store *Store // shared peer store; used to report peers that are already online

	// Buffered event queues, written by peerComeOnline/peerWentOffline and
	// drained by listenForEvents.
	onlineChan  chan messages.PeerID
	offlineChan chan messages.PeerID

	// Peer IDs this subscriber wants events for; guarded by mu.
	interestedPeersForOffline map[messages.PeerID]struct{}
	interestedPeersForOnline  map[messages.PeerID]struct{}

	mu sync.RWMutex

	// listenerCtx is assigned by listenForEvents once its goroutine starts.
	// NOTE(review): the assignment is not synchronized with the readers in
	// peerComeOnline/peerWentOffline (which hold mu), so it can be observed
	// as nil — or race — if an event arrives before the goroutine is
	// scheduled; confirm the intended synchronization.
	listenerCtx context.Context
}
// newListener constructs a Listener bound to the given store, with empty
// interest sets. The event channels are buffered to 244 entries — the
// maximum number of peer IDs a single relay protocol message can carry.
func newListener(store *Store) *Listener {
	return &Listener{
		store:                     store,
		onlineChan:                make(chan messages.PeerID, 244), // 244 = peer ID limit per relay message
		offlineChan:               make(chan messages.PeerID, 244), // 244 = peer ID limit per relay message
		interestedPeersForOffline: make(map[messages.PeerID]struct{}),
		interestedPeersForOnline:  make(map[messages.PeerID]struct{}),
	}
}
// AddInterestedPeers registers peerIDs in both the online and offline
// interest sets and returns the subset of them that is currently connected,
// so the caller can immediately report those peers as online.
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
	l.mu.Lock()
	defer l.mu.Unlock()

	online := make([]messages.PeerID, 0)
	for _, id := range peerIDs {
		l.interestedPeersForOnline[id] = struct{}{}
		l.interestedPeersForOffline[id] = struct{}{}

		// Already-connected peers are collected so the caller can respond
		// with their online status right away.
		if _, ok := l.store.Peer(id); ok {
			online = append(online, id)
		}
	}
	return online
}
// RemoveInterestedPeer drops the given peer IDs from both interest sets so
// that no further online/offline notifications are delivered for them.
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
	l.mu.Lock()
	defer l.mu.Unlock()

	for _, id := range peerIDs {
		delete(l.interestedPeersForOnline, id)
		delete(l.interestedPeersForOffline, id)
	}
}
// listenForEvents is the Listener's event loop; it must run in its own
// goroutine. It drains the online/offline channels in batches so that a
// burst of status changes results in a single callback invocation, and it
// exits when ctx is cancelled (via PeerNotifier.RemoveListener).
func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) {
	// NOTE(review): this write is not guarded by l.mu, while
	// peerComeOnline/peerWentOffline read listenerCtx under l.mu; an event
	// arriving before this goroutine is scheduled can observe nil — confirm
	// the intended synchronization.
	l.listenerCtx = ctx
	for {
		select {
		case <-ctx.Done():
			return
		case pID := <-l.onlineChan:
			peers := make([]messages.PeerID, 0)
			peers = append(peers, pID)
			// Drain everything already queued so one callback carries the batch.
			for len(l.onlineChan) > 0 {
				pID = <-l.onlineChan
				peers = append(peers, pID)
			}
			onPeersComeOnline(peers)
		case pID := <-l.offlineChan:
			peers := make([]messages.PeerID, 0)
			peers = append(peers, pID)
			// Drain everything already queued so one callback carries the batch.
			for len(l.offlineChan) > 0 {
				pID = <-l.offlineChan
				peers = append(peers, pID)
			}
			onPeersWentOffline(peers)
		}
	}
}
// peerWentOffline enqueues an offline event for peerID if this listener is
// interested in it. The send aborts when the listener's context is cancelled
// so a stopped event loop cannot block the notifier.
func (l *Listener) peerWentOffline(peerID messages.PeerID) {
	l.mu.RLock()
	defer l.mu.RUnlock()

	if _, ok := l.interestedPeersForOffline[peerID]; !ok {
		return
	}

	// listenerCtx is assigned by the listenForEvents goroutine; an event may
	// arrive before that goroutine has run. Guard against the nil context to
	// avoid a panic and enqueue best-effort (the channel is buffered).
	if l.listenerCtx == nil {
		select {
		case l.offlineChan <- peerID:
		default:
		}
		return
	}

	select {
	case l.offlineChan <- peerID:
	case <-l.listenerCtx.Done():
	}
}
// peerComeOnline enqueues an online event for peerID if this listener is
// interested in it. Online interest is one-shot: the ID is removed from the
// interest set after the event is queued, so the subscriber is told only once
// that a given peer came online.
func (l *Listener) peerComeOnline(peerID messages.PeerID) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if _, ok := l.interestedPeersForOnline[peerID]; !ok {
		return
	}

	// listenerCtx is assigned by the listenForEvents goroutine; an event may
	// arrive before that goroutine has run. Guard against the nil context to
	// avoid a panic and enqueue best-effort (the channel is buffered).
	if l.listenerCtx == nil {
		select {
		case l.onlineChan <- peerID:
			delete(l.interestedPeersForOnline, peerID)
		default:
		}
		return
	}

	select {
	case l.onlineChan <- peerID:
	case <-l.listenerCtx.Done():
	}
	delete(l.interestedPeersForOnline, peerID)
}

View File

@ -0,0 +1,64 @@
package store
import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
)
// PeerNotifier fans peer status changes out to registered Listeners. Each
// Listener runs its own event-loop goroutine whose lifetime is controlled by
// the CancelFunc stored alongside it.
type PeerNotifier struct {
	store *Store // shared peer store handed to every new Listener

	// listeners maps each active Listener to the CancelFunc that stops its
	// event-loop goroutine; guarded by listenersMutex.
	listeners      map[*Listener]context.CancelFunc
	listenersMutex sync.RWMutex
}
// NewPeerNotifier creates a PeerNotifier that uses store to look up currently
// connected peers and starts with no registered listeners.
func NewPeerNotifier(store *Store) *PeerNotifier {
	return &PeerNotifier{
		store:     store,
		listeners: make(map[*Listener]context.CancelFunc),
	}
}
// NewListener creates a Listener, starts its event-loop goroutine and
// registers it so it receives subsequent peer status events. The returned
// Listener stays active until RemoveListener is called for it.
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
	ctx, cancel := context.WithCancel(context.Background())

	listener := newListener(pn.store)
	// Publish the context before the goroutine starts: peerComeOnline and
	// peerWentOffline dereference listenerCtx, and relying on the goroutine
	// to set it leaves a window where an early event observes nil and
	// panics. The registration below (under listenersMutex) makes this
	// write visible to notifier callers before they can reach the listener.
	listener.listenerCtx = ctx
	go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline)

	pn.listenersMutex.Lock()
	pn.listeners[listener] = cancel
	pn.listenersMutex.Unlock()
	return listener
}
// RemoveListener unregisters listener and cancels its event-loop goroutine.
// Removing a listener that was never registered (or was already removed) is
// a no-op.
func (pn *PeerNotifier) RemoveListener(listener *Listener) {
	pn.listenersMutex.Lock()
	defer pn.listenersMutex.Unlock()

	if cancel, ok := pn.listeners[listener]; ok {
		cancel()
		delete(pn.listeners, listener)
	}
}
// PeerWentOffline tells every registered listener that peerID disconnected;
// listeners with no interest in this peer ignore the event.
func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) {
	pn.listenersMutex.RLock()
	defer pn.listenersMutex.RUnlock()

	for l := range pn.listeners {
		l.peerWentOffline(peerID)
	}
}
// PeerCameOnline tells every registered listener that peerID connected;
// listeners with no interest in this peer ignore the event.
func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) {
	pn.listenersMutex.RLock()
	defer pn.listenersMutex.RUnlock()

	for l := range pn.listeners {
		l.peerComeOnline(peerID)
	}
}

View File

@ -1,41 +1,48 @@
package server package store
import ( import (
"sync" "sync"
"github.com/netbirdio/netbird/relay/messages"
) )
type IPeer interface {
Close()
ID() messages.PeerID
}
// Store is a thread-safe store of peers // Store is a thread-safe store of peers
// It is used to store the peers that are connected to the relay server // It is used to store the peers that are connected to the relay server
type Store struct { type Store struct {
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster peers map[messages.PeerID]IPeer
peersLock sync.RWMutex peersLock sync.RWMutex
} }
// NewStore creates a new Store instance // NewStore creates a new Store instance
func NewStore() *Store { func NewStore() *Store {
return &Store{ return &Store{
peers: make(map[string]*Peer), peers: make(map[messages.PeerID]IPeer),
} }
} }
// AddPeer adds a peer to the store // AddPeer adds a peer to the store
func (s *Store) AddPeer(peer *Peer) { func (s *Store) AddPeer(peer IPeer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()] odlPeer, ok := s.peers[peer.ID()]
if ok { if ok {
odlPeer.Close() odlPeer.Close()
} }
s.peers[peer.String()] = peer s.peers[peer.ID()] = peer
} }
// DeletePeer deletes a peer from the store // DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer *Peer) { func (s *Store) DeletePeer(peer IPeer) {
s.peersLock.Lock() s.peersLock.Lock()
defer s.peersLock.Unlock() defer s.peersLock.Unlock()
dp, ok := s.peers[peer.String()] dp, ok := s.peers[peer.ID()]
if !ok { if !ok {
return return
} }
@ -43,11 +50,11 @@ func (s *Store) DeletePeer(peer *Peer) {
return return
} }
delete(s.peers, peer.String()) delete(s.peers, peer.ID())
} }
// Peer returns a peer by its ID // Peer returns a peer by its ID
func (s *Store) Peer(id string) (*Peer, bool) { func (s *Store) Peer(id messages.PeerID) (IPeer, bool) {
s.peersLock.RLock() s.peersLock.RLock()
defer s.peersLock.RUnlock() defer s.peersLock.RUnlock()
@ -56,11 +63,11 @@ func (s *Store) Peer(id string) (*Peer, bool) {
} }
// Peers returns all the peers in the store // Peers returns all the peers in the store
func (s *Store) Peers() []*Peer { func (s *Store) Peers() []IPeer {
s.peersLock.RLock() s.peersLock.RLock()
defer s.peersLock.RUnlock() defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers)) peers := make([]IPeer, 0, len(s.peers))
for _, p := range s.peers { for _, p := range s.peers {
peers = append(peers, p) peers = append(peers, p)
} }

View File

@ -0,0 +1,49 @@
package store
import (
"testing"
"github.com/netbirdio/netbird/relay/messages"
)
// MocPeer is a minimal IPeer implementation used by the store tests.
type MocPeer struct {
	id messages.PeerID
}

// Close satisfies IPeer; a mock has nothing to release.
func (m *MocPeer) Close() {}

// ID returns the mock's fixed peer ID.
func (m *MocPeer) ID() messages.PeerID {
	return m.id
}
// TestStore_DeletePeer checks the add/delete round trip: after DeletePeer the
// peer must no longer be resolvable by its ID.
func TestStore_DeletePeer(t *testing.T) {
	store := NewStore()

	peerID := messages.HashID("peer_one")
	peer := &MocPeer{id: peerID}
	store.AddPeer(peer)
	store.DeletePeer(peer)

	if _, found := store.Peer(peerID); found {
		t.Errorf("peer was not deleted")
	}
}
// TestStore_DeleteDeprecatedPeer simulates a reconnect: two peer instances
// share the same ID, the newer one replaces the older in the store, and
// deleting the stale (deprecated) instance must not remove the replacement.
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
	store := NewStore()

	deprecated := &MocPeer{id: messages.HashID("peer_one")}
	replacement := &MocPeer{id: messages.HashID("peer_one")}
	store.AddPeer(deprecated)
	store.AddPeer(replacement)

	// The store holds replacement under this ID now, so this must be a no-op.
	store.DeletePeer(deprecated)

	if _, found := store.Peer(replacement.ID()); !found {
		t.Errorf("second peer was deleted")
	}
}

View File

@ -1,85 +0,0 @@
package server
import (
"context"
"net"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
p := NewPeer(m, []byte("peer_one"), nil, nil)
s.AddPeer(p)
s.DeletePeer(p)
if _, ok := s.Peer(p.String()); ok {
t.Errorf("peer was not deleted")
}
}
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
s := NewStore()
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
conn := &mockConn{}
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
s.AddPeer(p1)
s.AddPeer(p2)
s.DeletePeer(p1)
if _, ok := s.Peer(p2.String()); !ok {
t.Errorf("second peer was deleted")
}
}

33
relay/server/url.go Normal file
View File

@ -0,0 +1,33 @@
package server
import (
"fmt"
"net/url"
"strings"
)
// getInstanceURL normalizes exposedAddress into a relay instance URL.
// When the address carries no scheme, "rels://" (TLS enabled) or "rel://"
// (plain) is prepended; either way the result must parse as a URL whose
// scheme is rel or rels, and the parsed form is returned.
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
	addr := exposedAddress
	switch strings.Count(exposedAddress, "://") {
	case 0:
		// No scheme supplied: derive one from the TLS setting.
		if tlsSupported {
			addr = "rels://" + exposedAddress
		} else {
			addr = "rel://" + exposedAddress
		}
	case 1:
		// Caller supplied a scheme; it is validated after parsing.
	default:
		return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
	}

	parsedURL, err := url.ParseRequestURI(addr)
	if err != nil {
		return "", fmt.Errorf("invalid exposed address: %v", err)
	}

	if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
		return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
	}

	return parsedURL.String(), nil
}

View File

@ -12,7 +12,6 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/turn/v3" "github.com/pion/turn/v3"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/auth/allow" "github.com/netbirdio/netbird/relay/auth/allow"
"github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/auth/hmac"
@ -22,7 +21,6 @@ import (
) )
var ( var (
av = &allow.Auth{}
hmacTokenStore = &hmac.TokenStore{} hmacTokenStore = &hmac.TokenStore{}
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
dataSize = 1024 * 1024 * 10 dataSize = 1024 * 1024 * 10
@ -70,8 +68,12 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
port := 35000 + peerPairs port := 35000 + peerPairs
serverAddress := fmt.Sprintf("127.0.0.1:%d", port) serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
serverConnURL := fmt.Sprintf("rel://%s", serverAddress) serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
serverCfg := server.Config{
srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av) ExposedAddress: serverConnURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
srv, err := server.NewServer(serverCfg)
if err != nil { if err != nil {
t.Fatalf("failed to create server: %s", err) t.Fatalf("failed to create server: %s", err)
} }
@ -98,8 +100,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
clientsSender := make([]*client.Client, peerPairs) clientsSender := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsSender); i++ { for i := 0; i < cap(clientsSender); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
err := c.Connect() err := c.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@ -108,8 +110,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
clientsReceiver := make([]*client.Client, peerPairs) clientsReceiver := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsReceiver); i++ { for i := 0; i < cap(clientsReceiver); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
err := c.Connect() err := c.Connect(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to connect to server: %s", err) t.Fatalf("failed to connect to server: %s", err)
} }
@ -119,13 +121,13 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
connsSender := make([]net.Conn, 0, peerPairs) connsSender := make([]net.Conn, 0, peerPairs)
connsReceiver := make([]net.Conn, 0, peerPairs) connsReceiver := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsSender); i++ { for i := 0; i < len(clientsSender); i++ {
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i))
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }
connsSender = append(connsSender, conn) connsSender = append(connsSender, conn)
conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) conn, err = clientsReceiver[i].OpenConn(ctx, "sender-"+fmt.Sprint(i))
if err != nil { if err != nil {
t.Fatalf("failed to bind channel: %s", err) t.Fatalf("failed to bind channel: %s", err)
} }

View File

@ -70,8 +70,8 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
ctx := context.Background() ctx := context.Background()
clientsSender := make([]*client.Client, peerPairs) clientsSender := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsSender); i++ { for i := 0; i < cap(clientsSender); i++ {
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
if err := c.Connect(); err != nil { if err := c.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
clientsSender[i] = c clientsSender[i] = c
@ -79,7 +79,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
connsSender := make([]net.Conn, 0, peerPairs) connsSender := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsSender); i++ { for i := 0; i < len(clientsSender); i++ {
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i))
if err != nil { if err != nil {
log.Fatalf("failed to bind channel: %s", err) log.Fatalf("failed to bind channel: %s", err)
} }
@ -156,8 +156,8 @@ func runReader(conn net.Conn) time.Duration {
func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
clientsReceiver := make([]*client.Client, peerPairs) clientsReceiver := make([]*client.Client, peerPairs)
for i := 0; i < cap(clientsReceiver); i++ { for i := 0; i < cap(clientsReceiver); i++ {
c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
err := c.Connect() err := c.Connect(context.Background())
if err != nil { if err != nil {
log.Fatalf("failed to connect to server: %s", err) log.Fatalf("failed to connect to server: %s", err)
} }
@ -166,7 +166,7 @@ func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
connsReceiver := make([]net.Conn, 0, peerPairs) connsReceiver := make([]net.Conn, 0, peerPairs)
for i := 0; i < len(clientsReceiver); i++ { for i := 0; i < len(clientsReceiver); i++ {
conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) conn, err := clientsReceiver[i].OpenConn(context.Background(), "sender-"+fmt.Sprint(i))
if err != nil { if err != nil {
log.Fatalf("failed to bind channel: %s", err) log.Fatalf("failed to bind channel: %s", err)
} }