mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-13 17:07:30 +02:00
[client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes. - The server manages and maintains these subscriptions. - Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
This commit is contained in:
@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyBind struct {
|
type ProxyBind struct {
|
||||||
@ -28,6 +29,17 @@ type ProxyBind struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||||
|
p := &ProxyBind{
|
||||||
|
Bind: bind,
|
||||||
|
closeListener: listener.NewCloseListener(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn adds a new connection to the bind.
|
// AddTurnConn adds a new connection to the bind.
|
||||||
@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||||
|
p.closeListener.SetCloseListener(disconnected)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) Work() {
|
func (p *ProxyBind) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
|
|||||||
if p.closed {
|
if p.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
p.closeListener.Notify()
|
||||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
@ -26,6 +28,15 @@ type ProxyWrapper struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
||||||
|
return &ProxyWrapper{
|
||||||
|
WgeBPFProxy: WgeBPFProxy,
|
||||||
|
closeListener: listener.NewCloseListener(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
|||||||
return p.wgEndpointAddr
|
return p.wgEndpointAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||||
|
p.closeListener.SetCloseListener(disconnected)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) Work() {
|
func (p *ProxyWrapper) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
return
|
return
|
||||||
@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
|
|||||||
|
|
||||||
e.cancel()
|
e.cancel()
|
||||||
|
|
||||||
|
e.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return 0, ctx.Err()
|
return 0, ctx.Err()
|
||||||
}
|
}
|
||||||
|
p.closeListener.Notify()
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||||
}
|
}
|
||||||
|
@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ebpf.ProxyWrapper{
|
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||||
WgeBPFProxy: w.ebpfProxy,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
|
@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) GetProxy() Proxy {
|
func (w *USPFactory) GetProxy() Proxy {
|
||||||
return &proxyBind.ProxyBind{
|
return proxyBind.NewProxyBind(w.bind)
|
||||||
Bind: w.bind,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
|
19
client/iface/wgproxy/listener/listener.go
Normal file
19
client/iface/wgproxy/listener/listener.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package listener
|
||||||
|
|
||||||
|
type CloseListener struct {
|
||||||
|
listener func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCloseListener() *CloseListener {
|
||||||
|
return &CloseListener{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CloseListener) SetCloseListener(listener func()) {
|
||||||
|
c.listener = listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CloseListener) Notify() {
|
||||||
|
if c.listener != nil {
|
||||||
|
c.listener()
|
||||||
|
}
|
||||||
|
}
|
@ -12,4 +12,5 @@ type Proxy interface {
|
|||||||
Work() // Work start or resume the proxy
|
Work() // Work start or resume the proxy
|
||||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
|
SetDisconnectListener(disconnected func())
|
||||||
}
|
}
|
||||||
|
@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
proxyWrapper := &ebpf.ProxyWrapper{
|
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
||||||
WgeBPFProxy: ebpfProxy,
|
|
||||||
}
|
|
||||||
|
|
||||||
tests = append(tests, struct {
|
tests = append(tests, struct {
|
||||||
name string
|
name string
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
cerrors "github.com/netbirdio/netbird/client/errors"
|
cerrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGUDPProxy proxies
|
// WGUDPProxy proxies
|
||||||
@ -28,6 +29,8 @@ type WGUDPProxy struct {
|
|||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
|
closeListener *listener.CloseListener
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||||
@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
|||||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
p := &WGUDPProxy{
|
p := &WGUDPProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
|||||||
return endpointUdpAddr
|
return endpointUdpAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
|
||||||
|
p.closeListener.SetCloseListener(disconnected)
|
||||||
|
}
|
||||||
|
|
||||||
// Work starts the proxy or resumes it if it was paused
|
// Work starts the proxy or resumes it if it was paused
|
||||||
func (p *WGUDPProxy) Work() {
|
func (p *WGUDPProxy) Work() {
|
||||||
if p.remoteConn == nil {
|
if p.remoteConn == nil {
|
||||||
@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
|
|||||||
if p.closed {
|
if p.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.closeListener.SetCloseListener(nil)
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
p.closeListener.Notify()
|
||||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -167,7 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
@ -489,6 +489,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
|
||||||
|
|
||||||
conn.dumpState.NewLocalProxy()
|
conn.dumpState.NewLocalProxy()
|
||||||
|
|
||||||
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
|
@ -19,6 +19,7 @@ type RelayConnInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WorkerRelay struct {
|
type WorkerRelay struct {
|
||||||
|
peerCtx context.Context
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isController bool
|
isController bool
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
@ -33,8 +34,9 @@ type WorkerRelay struct {
|
|||||||
wgWatcher *WGWatcher
|
wgWatcher *WGWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
|
peerCtx: ctx,
|
||||||
log: log,
|
log: log,
|
||||||
isController: ctrl,
|
isController: ctrl,
|
||||||
config: config,
|
config: config,
|
||||||
@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
|
|
||||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||||
|
|
||||||
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
|
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||||
|
@ -7,13 +7,6 @@ import (
|
|||||||
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validator is an interface that defines the Validate method.
|
|
||||||
type Validator interface {
|
|
||||||
Validate(any) error
|
|
||||||
// Deprecated: Use Validate instead.
|
|
||||||
ValidateHelloMsgType(any) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TimedHMACValidator struct {
|
type TimedHMACValidator struct {
|
||||||
authenticatorV2 *authv2.Validator
|
authenticatorV2 *authv2.Validator
|
||||||
authenticator *auth.TimedHMACValidator
|
authenticator *auth.TimedHMACValidator
|
||||||
|
@ -124,15 +124,14 @@ func (cc *connContainer) close() {
|
|||||||
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
parentCtx context.Context
|
|
||||||
connectionURL string
|
connectionURL string
|
||||||
authTokenStore *auth.TokenStore
|
authTokenStore *auth.TokenStore
|
||||||
hashedID []byte
|
hashedID messages.PeerID
|
||||||
|
|
||||||
bufPool *sync.Pool
|
bufPool *sync.Pool
|
||||||
|
|
||||||
relayConn net.Conn
|
relayConn net.Conn
|
||||||
conns map[string]*connContainer
|
conns map[messages.PeerID]*connContainer
|
||||||
serviceIsRunning bool
|
serviceIsRunning bool
|
||||||
mu sync.Mutex // protect serviceIsRunning and conns
|
mu sync.Mutex // protect serviceIsRunning and conns
|
||||||
readLoopMutex sync.Mutex
|
readLoopMutex sync.Mutex
|
||||||
@ -142,14 +141,17 @@ type Client struct {
|
|||||||
|
|
||||||
onDisconnectListener func(string)
|
onDisconnectListener func(string)
|
||||||
listenerMutex sync.Mutex
|
listenerMutex sync.Mutex
|
||||||
|
|
||||||
|
stateSubscription *PeersStateSubscription
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID := messages.HashID(peerID)
|
||||||
|
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
log: log.WithFields(log.Fields{"relay": serverURL}),
|
log: relayLog,
|
||||||
parentCtx: ctx,
|
|
||||||
connectionURL: serverURL,
|
connectionURL: serverURL,
|
||||||
authTokenStore: authTokenStore,
|
authTokenStore: authTokenStore,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
|||||||
return &buf
|
return &buf
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
conns: make(map[string]*connContainer),
|
conns: make(map[messages.PeerID]*connContainer),
|
||||||
}
|
}
|
||||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
|
|
||||||
|
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect(ctx context.Context) error {
|
||||||
c.log.Infof("connecting to relay server")
|
c.log.Infof("connecting to relay server")
|
||||||
c.readLoopMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
defer c.readLoopMutex.Unlock()
|
defer c.readLoopMutex.Unlock()
|
||||||
@ -178,17 +181,23 @@ func (c *Client) Connect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.connect(); err != nil {
|
if err := c.connect(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
|
||||||
|
|
||||||
c.log = c.log.WithField("relay", c.instanceURL.String())
|
c.log = c.log.WithField("relay", c.instanceURL.String())
|
||||||
c.log.Infof("relay connection established")
|
c.log.Infof("relay connection established")
|
||||||
|
|
||||||
c.serviceIsRunning = true
|
c.serviceIsRunning = true
|
||||||
|
|
||||||
|
internallyStoppedFlag := newInternalStopFlag()
|
||||||
|
hc := healthcheck.NewReceiver(c.log)
|
||||||
|
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
|
||||||
|
|
||||||
c.wgReadLoop.Add(1)
|
c.wgReadLoop.Add(1)
|
||||||
go c.readLoop(c.relayConn)
|
go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -196,26 +205,41 @@ func (c *Client) Connect() error {
|
|||||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||||
// it will return immediately.
|
// it will return immediately.
|
||||||
|
// It block until the server confirm the peer is online.
|
||||||
// todo: what should happen if call with the same peerID with multiple times?
|
// todo: what should happen if call with the same peerID with multiple times?
|
||||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
|
||||||
c.mu.Lock()
|
peerID := messages.HashID(dstPeerID)
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, fmt.Errorf("relay connection is not established")
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
}
|
}
|
||||||
|
_, ok := c.conns[peerID]
|
||||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
|
||||||
_, ok := c.conns[hashedStringID]
|
|
||||||
if ok {
|
if ok {
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, ErrConnAlreadyExists
|
return nil, ErrConnAlreadyExists
|
||||||
}
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
c.log.Infof("open connection to peer: %s", hashedStringID)
|
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||||
|
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
||||||
msgChannel := make(chan Msg, 100)
|
msgChannel := make(chan Msg, 100)
|
||||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
conn := NewConn(c, peerID, msgChannel, c.instanceURL)
|
||||||
|
|
||||||
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
|
c.mu.Lock()
|
||||||
|
_, ok = c.conns[peerID]
|
||||||
|
if ok {
|
||||||
|
c.mu.Unlock()
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, ErrConnAlreadyExists
|
||||||
|
}
|
||||||
|
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||||
|
c.mu.Unlock()
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -254,7 +278,7 @@ func (c *Client) Close() error {
|
|||||||
return c.close(true)
|
return c.close(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connect() error {
|
func (c *Client) connect(ctx context.Context) error {
|
||||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -262,7 +286,7 @@ func (c *Client) connect() error {
|
|||||||
}
|
}
|
||||||
c.relayConn = conn
|
c.relayConn = conn
|
||||||
|
|
||||||
if err = c.handShake(); err != nil {
|
if err = c.handShake(ctx); err != nil {
|
||||||
cErr := conn.Close()
|
cErr := conn.Close()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
c.log.Errorf("failed to close connection: %s", cErr)
|
c.log.Errorf("failed to close connection: %s", cErr)
|
||||||
@ -273,7 +297,7 @@ func (c *Client) connect() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake(ctx context.Context) error {
|
||||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||||
@ -286,7 +310,7 @@ func (c *Client) handShake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||||
n, err := c.readWithTimeout(buf)
|
n, err := c.readWithTimeout(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to read auth response: %s", err)
|
c.log.Errorf("failed to read auth response: %s", err)
|
||||||
return err
|
return err
|
||||||
@ -319,11 +343,7 @@ func (c *Client) handShake() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readLoop(relayConn net.Conn) {
|
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
|
||||||
internallyStoppedFlag := newInternalStopFlag()
|
|
||||||
hc := healthcheck.NewReceiver(c.log)
|
|
||||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errExit error
|
errExit error
|
||||||
n int
|
n int
|
||||||
@ -370,6 +390,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
|||||||
c.instanceURL = nil
|
c.instanceURL = nil
|
||||||
c.muInstanceURL.Unlock()
|
c.muInstanceURL.Unlock()
|
||||||
|
|
||||||
|
c.stateSubscription.Cleanup()
|
||||||
c.wgReadLoop.Done()
|
c.wgReadLoop.Done()
|
||||||
_ = c.close(false)
|
_ = c.close(false)
|
||||||
c.notifyDisconnected()
|
c.notifyDisconnected()
|
||||||
@ -382,6 +403,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
|||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
case messages.MsgTypeTransport:
|
case messages.MsgTypeTransport:
|
||||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||||
|
case messages.MsgTypePeersOnline:
|
||||||
|
c.handlePeersOnlineMsg(buf)
|
||||||
|
c.bufPool.Put(bufPtr)
|
||||||
|
return true
|
||||||
|
case messages.MsgTypePeersWentOffline:
|
||||||
|
c.handlePeersWentOfflineMsg(buf)
|
||||||
|
c.bufPool.Put(bufPtr)
|
||||||
|
return true
|
||||||
case messages.MsgTypeClose:
|
case messages.MsgTypeClose:
|
||||||
c.log.Debugf("relay connection close by server")
|
c.log.Debugf("relay connection close by server")
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
@ -413,18 +442,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
stringID := messages.HashIDToString(peerID)
|
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if !c.serviceIsRunning {
|
if !c.serviceIsRunning {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
container, ok := c.conns[stringID]
|
container, ok := c.conns[*peerID]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.log.Errorf("peer not found: %s", stringID)
|
c.log.Errorf("peer not found: %s", peerID.String())
|
||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -437,9 +464,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
|
func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
conn, ok := c.conns[id]
|
conn, ok := c.conns[dstID]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
@ -464,7 +491,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
|||||||
return len(payload), err
|
return len(payload), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case _, ok := <-hc.OnTimeout:
|
case _, ok := <-hc.OnTimeout:
|
||||||
@ -478,7 +505,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
|||||||
c.log.Warnf("failed to close connection: %s", err)
|
c.log.Warnf("failed to close connection: %s", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-c.parentCtx.Done():
|
case <-ctx.Done():
|
||||||
err := c.close(true)
|
err := c.close(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("failed to teardown connection: %s", err)
|
c.log.Errorf("failed to teardown connection: %s", err)
|
||||||
@ -492,10 +519,31 @@ func (c *Client) closeAllConns() {
|
|||||||
for _, container := range c.conns {
|
for _, container := range c.conns {
|
||||||
container.close()
|
container.close()
|
||||||
}
|
}
|
||||||
c.conns = make(map[string]*connContainer)
|
c.conns = make(map[messages.PeerID]*connContainer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) closeConn(connReference *Conn, id string) error {
|
func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
container, ok := c.conns[peerID]
|
||||||
|
if !ok {
|
||||||
|
c.log.Warnf("can not close connection, peer not found: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
|
||||||
|
container.close()
|
||||||
|
delete(c.conns, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
|
||||||
|
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
@ -507,6 +555,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
|||||||
if container.conn != connReference {
|
if container.conn != connReference {
|
||||||
return fmt.Errorf("conn reference mismatch")
|
return fmt.Errorf("conn reference mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
|
||||||
|
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
c.log.Infof("free up connection to peer: %s", id)
|
c.log.Infof("free up connection to peer: %s", id)
|
||||||
delete(c.conns, id)
|
delete(c.conns, id)
|
||||||
container.close()
|
container.close()
|
||||||
@ -559,8 +612,8 @@ func (c *Client) writeCloseMsg() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
|
||||||
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
|
ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
readDone := make(chan struct{})
|
readDone := make(chan struct{})
|
||||||
@ -581,3 +634,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) handlePeersOnlineMsg(buf []byte) {
|
||||||
|
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.stateSubscription.OnPeersOnline(peersID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||||
|
peersID, err := messages.UnMarshalPeersWentOffline(buf)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||||
|
}
|
||||||
|
@ -18,14 +18,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
av = &allow.Auth{}
|
|
||||||
hmacTokenStore = &hmac.TokenStore{}
|
hmacTokenStore = &hmac.TokenStore{}
|
||||||
serverListenAddr = "127.0.0.1:1234"
|
serverListenAddr = "127.0.0.1:1234"
|
||||||
serverURL = "rel://127.0.0.1:1234"
|
serverURL = "rel://127.0.0.1:1234"
|
||||||
|
serverCfg = server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: serverURL,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
_ = util.InitLog("error", "console")
|
_ = util.InitLog("debug", "console")
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
@ -33,7 +38,7 @@ func TestMain(m *testing.M) {
|
|||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -58,37 +63,37 @@ func TestClient(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
t.Log("alice connecting to server")
|
t.Log("alice connecting to server")
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientAlice.Close()
|
defer clientAlice.Close()
|
||||||
|
|
||||||
t.Log("placeholder connecting to server")
|
t.Log("placeholder connecting to server")
|
||||||
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
|
clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
|
||||||
err = clientPlaceHolder.Connect()
|
err = clientPlaceHolder.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientPlaceHolder.Close()
|
defer clientPlaceHolder.Close()
|
||||||
|
|
||||||
t.Log("Bob connecting to server")
|
t.Log("Bob connecting to server")
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
defer clientBob.Close()
|
defer clientBob.Close()
|
||||||
|
|
||||||
t.Log("Alice open connection to Bob")
|
t.Log("Alice open connection to Bob")
|
||||||
connAliceToBob, err := clientAlice.OpenConn("bob")
|
connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("Bob open connection to Alice")
|
t.Log("Bob open connection to Alice")
|
||||||
connBobToAlice, err := clientBob.OpenConn("alice")
|
connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -115,7 +120,7 @@ func TestClient(t *testing.T) {
|
|||||||
func TestRegistration(t *testing.T) {
|
func TestRegistration(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = srv.Shutdown(ctx)
|
_ = srv.Shutdown(ctx)
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) {
|
|||||||
_ = fakeTCPListener.Close()
|
_ = fakeTCPListener.Close()
|
||||||
}(fakeTCPListener)
|
}(fakeTCPListener)
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
|
clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -189,7 +194,7 @@ func TestEcho(t *testing.T) {
|
|||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -213,8 +218,8 @@ func TestEcho(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -225,8 +230,8 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -237,12 +242,12 @@ func TestEcho(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err == nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("expected error when binding to unavailable peer, got nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("closing client")
|
log.Infof("closing client")
|
||||||
@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
|
chBob, err := clientBob.OpenConn(ctx, "alice")
|
||||||
err = clientBob.Connect()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chBob, err := clientBob.OpenConn("alice")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
chAlice, err := clientAlice.OpenConn("bob")
|
chAlice, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
testString := "hello alice, I am bob"
|
testString := "hello alice, I am bob"
|
||||||
|
_, err = chBob.Write([]byte(testString))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error when writing to channel, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
chBob, err = clientBob.OpenConn(ctx, "alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = chBob.Write([]byte(testString))
|
_, err = chBob.Write([]byte(testString))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to write to channel: %s", err)
|
t.Errorf("failed to write to channel: %s", err)
|
||||||
@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
bob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
err = clientAlice.Connect()
|
err = bob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to connect to server: %s", err)
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := clientAlice.OpenConn("bob")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
|
err = clientAlice.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
|
bob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||||
err = clientAlice.Connect()
|
err = bob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := clientAlice.OpenConn("bob")
|
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||||
|
err = clientAlice.Connect(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := clientAlice.OpenConn(ctx, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) {
|
|||||||
t.Errorf("unexpected reading from closed connection")
|
t.Errorf("unexpected reading from closed connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = clientAlice.OpenConn("bob")
|
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv1, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = relayClient.Connect()
|
err = relayClient.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) {
|
|||||||
log.Fatalf("timeout waiting for client to disconnect")
|
log.Fatalf("timeout waiting for client to disconnect")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = relayClient.OpenConn("bob")
|
_, err = relayClient.OpenConn(ctx, "bob")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
|
|
||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = relayClient.Connect()
|
err = relayClient.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) {
|
|||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = relayClient.OpenConn("bob")
|
_, err = relayClient.OpenConn(ctx, "bob")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("unexpected opening connection to closed server")
|
t.Errorf("unexpected opening connection to closed server")
|
||||||
}
|
}
|
||||||
@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
idAlice := "alice"
|
idAlice := "alice"
|
||||||
idBob := "bob"
|
idBob := "bob"
|
||||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
|
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||||
err = clientAlice.Connect()
|
err = clientAlice.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
|
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
|
||||||
err = clientBob.Connect()
|
err = clientBob.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -3,13 +3,14 @@ package client
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn represent a connection to a relayed remote peer.
|
// Conn represent a connection to a relayed remote peer.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
client *Client
|
client *Client
|
||||||
dstID []byte
|
dstID messages.PeerID
|
||||||
dstStringID string
|
|
||||||
messageChan chan Msg
|
messageChan chan Msg
|
||||||
instanceURL *RelayAddr
|
instanceURL *RelayAddr
|
||||||
}
|
}
|
||||||
@ -17,14 +18,12 @@ type Conn struct {
|
|||||||
// NewConn creates a new connection to a relayed remote peer.
|
// NewConn creates a new connection to a relayed remote peer.
|
||||||
// client: the client instance, it used to send messages to the destination peer
|
// client: the client instance, it used to send messages to the destination peer
|
||||||
// dstID: the destination peer ID
|
// dstID: the destination peer ID
|
||||||
// dstStringID: the destination peer ID in string format
|
|
||||||
// messageChan: the channel where the messages will be received
|
// messageChan: the channel where the messages will be received
|
||||||
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
||||||
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
client: client,
|
client: client,
|
||||||
dstID: dstID,
|
dstID: dstID,
|
||||||
dstStringID: dstStringID,
|
|
||||||
messageChan: messageChan,
|
messageChan: messageChan,
|
||||||
instanceURL: instanceURL,
|
instanceURL: instanceURL,
|
||||||
}
|
}
|
||||||
@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||||
return c.client.writeTo(c, c.dstStringID, c.dstID, p)
|
return c.client.writeTo(c, c.dstID, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return c.client.closeConn(c, c.dstStringID)
|
return c.client.closeConn(c, c.dstID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) LocalAddr() net.Addr {
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool
|
|||||||
|
|
||||||
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
||||||
|
|
||||||
if err := rc.Connect(); err != nil {
|
if err := rc.Connect(parentCtx); err != nil {
|
||||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ type OnServerCloseListener func()
|
|||||||
// ManagerService is the interface for the relay manager.
|
// ManagerService is the interface for the relay manager.
|
||||||
type ManagerService interface {
|
type ManagerService interface {
|
||||||
Serve() error
|
Serve() error
|
||||||
OpenConn(serverAddress, peerKey string) (net.Conn, error)
|
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
|
||||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||||
RelayInstanceAddress() (string, error)
|
RelayInstanceAddress() (string, error)
|
||||||
ServerURLs() []string
|
ServerURLs() []string
|
||||||
@ -123,7 +123,7 @@ func (m *Manager) Serve() error {
|
|||||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||||
m.relayClientMu.Lock()
|
m.relayClientMu.Lock()
|
||||||
defer m.relayClientMu.Unlock()
|
defer m.relayClientMu.Unlock()
|
||||||
|
|
||||||
@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
)
|
)
|
||||||
if !foreign {
|
if !foreign {
|
||||||
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
||||||
netConn, err = m.relayClient.OpenConn(peerKey)
|
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||||
netConn, err = m.openConnVia(serverAddress, peerKey)
|
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
|||||||
return m.tokenStore.UpdateToken(token)
|
return m.tokenStore.UpdateToken(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||||
// check if already has a connection to the desired relay server
|
// check if already has a connection to the desired relay server
|
||||||
m.relayClientsMutex.RLock()
|
m.relayClientsMutex.RLock()
|
||||||
rt, ok := m.relayClients[serverAddress]
|
rt, ok := m.relayClients[serverAddress]
|
||||||
@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
if rt.err != nil {
|
if rt.err != nil {
|
||||||
return nil, rt.err
|
return nil, rt.err
|
||||||
}
|
}
|
||||||
return rt.relayClient.OpenConn(peerKey)
|
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||||
}
|
}
|
||||||
m.relayClientsMutex.RUnlock()
|
m.relayClientsMutex.RUnlock()
|
||||||
|
|
||||||
@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
if rt.err != nil {
|
if rt.err != nil {
|
||||||
return nil, rt.err
|
return nil, rt.err
|
||||||
}
|
}
|
||||||
return rt.relayClient.OpenConn(peerKey)
|
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a new relay client and store it in the relayClients map
|
// create a new relay client and store it in the relayClients map
|
||||||
@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
m.relayClients[serverAddress] = rt
|
m.relayClients[serverAddress] = rt
|
||||||
m.relayClientsMutex.Unlock()
|
m.relayClientsMutex.Unlock()
|
||||||
|
|
||||||
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
|
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect(m.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rt.err = err
|
rt.err = err
|
||||||
rt.Unlock()
|
rt.Unlock()
|
||||||
@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
|||||||
rt.relayClient = relayClient
|
rt.relayClient = relayClient
|
||||||
rt.Unlock()
|
rt.Unlock()
|
||||||
|
|
||||||
conn, err := relayClient.OpenConn(peerKey)
|
conn, err := relayClient.OpenConn(ctx, peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/auth/allow"
|
||||||
"github.com/netbirdio/netbird/relay/server"
|
"github.com/netbirdio/netbird/relay/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) {
|
|||||||
func TestForeignConn(t *testing.T) {
|
func TestForeignConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg1 := server.ListenerConfig{
|
lstCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
|
||||||
|
srv1, err := server.NewServer(server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: lstCfg1.Address,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(srvCfg1)
|
err := srv1.Listen(lstCfg1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) {
|
|||||||
srvCfg2 := server.ListenerConfig{
|
srvCfg2 := server.ListenerConfig{
|
||||||
Address: "localhost:2234",
|
Address: "localhost:2234",
|
||||||
}
|
}
|
||||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
srv2, err := server.NewServer(server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: srvCfg2.Address,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
|
||||||
err = clientAlice.Serve()
|
if err := clientAlice.Serve(); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idBob := "bob"
|
clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
|
||||||
log.Debugf("connect by bob")
|
if err := clientBob.Serve(); err != nil {
|
||||||
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
|
|
||||||
err = clientBob.Serve()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get relay address: %s", err)
|
t.Fatalf("failed to get relay address: %s", err)
|
||||||
}
|
}
|
||||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
|
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
|
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
srvCfg1 := server.ListenerConfig{
|
srvCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
srv1, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
srvCfg2 := server.ListenerConfig{
|
srvCfg2 := server.ListenerConfig{
|
||||||
Address: "localhost:2234",
|
Address: "localhost:2234",
|
||||||
}
|
}
|
||||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
srv2, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
|
||||||
|
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
|
||||||
|
if err := mgrBob.Serve(); err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
|
||||||
err = mgr.Serve()
|
err = mgr.Serve()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -212,7 +222,7 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
srvCfg1 := server.ListenerConfig{
|
srvCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
srv1, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -241,7 +251,7 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
srvCfg2 := server.ListenerConfig{
|
srvCfg2 := server.ListenerConfig{
|
||||||
Address: "localhost:2234",
|
Address: "localhost:2234",
|
||||||
}
|
}
|
||||||
srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
|
srv2, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -277,7 +287,7 @@ func TestForeginAutoClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Log("open connection to another peer")
|
t.Log("open connection to another peer")
|
||||||
conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
|
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -305,7 +315,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
srvCfg := server.ListenerConfig{
|
srvCfg := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -330,6 +340,13 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
|
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
|
||||||
|
err = clientBob.Serve()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
||||||
err = clientAlice.Serve()
|
err = clientAlice.Serve()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -339,7 +356,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get relay address: %s", err)
|
t.Errorf("failed to get relay address: %s", err)
|
||||||
}
|
}
|
||||||
conn, err := clientAlice.OpenConn(ra, "bob")
|
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to bind channel: %s", err)
|
t.Errorf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -357,7 +374,7 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||||
|
|
||||||
log.Infof("reopent the connection")
|
log.Infof("reopent the connection")
|
||||||
_, err = clientAlice.OpenConn(ra, "bob")
|
_, err = clientAlice.OpenConn(ctx, ra, "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to open channel: %s", err)
|
t.Errorf("failed to open channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -366,24 +383,27 @@ func TestAutoReconnect(t *testing.T) {
|
|||||||
func TestNotifierDoubleAdd(t *testing.T) {
|
func TestNotifierDoubleAdd(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
srvCfg1 := server.ListenerConfig{
|
listenerCfg1 := server.ListenerConfig{
|
||||||
Address: "localhost:1234",
|
Address: "localhost:1234",
|
||||||
}
|
}
|
||||||
srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
|
srv, err := server.NewServer(server.Config{
|
||||||
|
Meter: otel.Meter(""),
|
||||||
|
ExposedAddress: listenerCfg1.Address,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
err := srv1.Listen(srvCfg1)
|
if err := srv.Listen(listenerCfg1); err != nil {
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
errChan <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := srv1.Shutdown(ctx)
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to close server: %s", err)
|
t.Errorf("failed to close server: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -392,17 +412,21 @@ func TestNotifierDoubleAdd(t *testing.T) {
|
|||||||
t.Fatalf("failed to start server: %s", err)
|
t.Fatalf("failed to start server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idAlice := "alice"
|
|
||||||
log.Debugf("connect by alice")
|
log.Debugf("connect by alice")
|
||||||
mCtx, cancel := context.WithCancel(ctx)
|
mCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
|
||||||
err = clientAlice.Serve()
|
clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
|
||||||
if err != nil {
|
if err = clientBob.Serve(); err != nil {
|
||||||
t.Fatalf("failed to serve manager: %s", err)
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
|
clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
|
||||||
|
if err = clientAlice.Serve(); err != nil {
|
||||||
|
t.Fatalf("failed to serve manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
168
relay/client/peer_subscription.go
Normal file
168
relay/client/peer_subscription.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpenConnectionTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayedConnWriter interface {
|
||||||
|
Write(p []byte) (n int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
|
||||||
|
// over a relay connection. It allows tracking peers' availability and handling offline
|
||||||
|
// events via a callback. We get online notification from the server only once.
|
||||||
|
type PeersStateSubscription struct {
|
||||||
|
log *log.Entry
|
||||||
|
relayConn relayedConnWriter
|
||||||
|
offlineCallback func(peerIDs []messages.PeerID)
|
||||||
|
|
||||||
|
listenForOfflinePeers map[messages.PeerID]struct{}
|
||||||
|
waitingPeers map[messages.PeerID]chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
|
||||||
|
return &PeersStateSubscription{
|
||||||
|
log: log,
|
||||||
|
relayConn: relayConn,
|
||||||
|
offlineCallback: offlineCallback,
|
||||||
|
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
|
||||||
|
waitingPeers: make(map[messages.PeerID]chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnPeersOnline should be called when a notification is received that certain peers have come online.
|
||||||
|
// It checks if any of the peers are being waited on and signals their availability.
|
||||||
|
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
|
||||||
|
for _, peerID := range peersID {
|
||||||
|
waitCh, ok := s.waitingPeers[peerID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
close(waitCh)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
|
||||||
|
relevantPeers := make([]messages.PeerID, 0, len(peersID))
|
||||||
|
for _, peerID := range peersID {
|
||||||
|
if _, ok := s.listenForOfflinePeers[peerID]; ok {
|
||||||
|
relevantPeers = append(relevantPeers, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relevantPeers) > 0 {
|
||||||
|
s.offlineCallback(relevantPeers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
|
||||||
|
// todo: when we unsubscribe while this is running, this will not return with error
|
||||||
|
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
|
||||||
|
// Check if already waiting for this peer
|
||||||
|
if _, exists := s.waitingPeers[peerID]; exists {
|
||||||
|
return errors.New("already waiting for peer to come online")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a channel to wait for the peer to come online
|
||||||
|
waitCh := make(chan struct{})
|
||||||
|
s.waitingPeers[peerID] = waitCh
|
||||||
|
|
||||||
|
if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil {
|
||||||
|
s.log.Errorf("failed to subscribe to peer state: %s", err)
|
||||||
|
close(waitCh)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
|
||||||
|
close(waitCh)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for peer to come online or context to be cancelled
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
select {
|
||||||
|
case <-waitCh:
|
||||||
|
s.log.Debugf("peer %s is now online", peerID)
|
||||||
|
return nil
|
||||||
|
case <-timeoutCtx.Done():
|
||||||
|
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
|
||||||
|
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
|
||||||
|
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
|
||||||
|
}
|
||||||
|
return timeoutCtx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||||
|
msgErr := s.unsubscribeStateChange(peerIDs)
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
if wch, ok := s.waitingPeers[peerID]; ok {
|
||||||
|
close(wch)
|
||||||
|
delete(s.waitingPeers, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.listenForOfflinePeers, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) Cleanup() {
|
||||||
|
for _, waitCh := range s.waitingPeers {
|
||||||
|
close(waitCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.waitingPeers = make(map[messages.PeerID]chan struct{})
|
||||||
|
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error {
|
||||||
|
msgs, err := messages.MarshalSubPeerStateMsg(peerIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peerIDs {
|
||||||
|
s.listenForOfflinePeers[peer] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if _, err := s.relayConn.Write(msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||||
|
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var connWriteErr error
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if _, err := s.relayConn.Write(msg); err != nil {
|
||||||
|
connWriteErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connWriteErr
|
||||||
|
}
|
99
relay/client/peer_subscription_test.go
Normal file
99
relay/client/peer_subscription_test.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockRelayedConn struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
|
||||||
|
peerID := messages.HashID("peer1")
|
||||||
|
mockConn := &mockRelayedConn{}
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.SetOutput(&bytes.Buffer{}) // discard log output
|
||||||
|
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Launch wait in background
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
sub.OnPeersOnline([]messages.PeerID{peerID})
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
|
||||||
|
peerID := messages.HashID("peer2")
|
||||||
|
mockConn := &mockRelayedConn{}
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.SetOutput(&bytes.Buffer{})
|
||||||
|
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
|
||||||
|
peerID := messages.HashID("peer3")
|
||||||
|
mockConn := &mockRelayedConn{}
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.SetOutput(&bytes.Buffer{})
|
||||||
|
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
go func() {
|
||||||
|
_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||||
|
|
||||||
|
}()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "already waiting")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsubscribeStateChange(t *testing.T) {
|
||||||
|
peerID := messages.HashID("peer4")
|
||||||
|
mockConn := &mockRelayedConn{}
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.SetOutput(&bytes.Buffer{})
|
||||||
|
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||||
|
|
||||||
|
doneChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
|
||||||
|
close(doneChan)
|
||||||
|
}()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-doneChan:
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
// Expected timeout, meaning the subscription was successfully unsubscribed
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
}
|
@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
|||||||
|
|
||||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||||
log.Infof("try to connecting to relay server: %s", url)
|
log.Infof("try to connecting to relay server: %s", url)
|
||||||
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
|
relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
|
||||||
err := relayClient.Connect()
|
err := relayClient.Connect(ctx)
|
||||||
resultChan <- connResult{
|
resultChan <- connResult{
|
||||||
RelayClient: relayClient,
|
RelayClient: relayClient,
|
||||||
Url: url,
|
Url: url,
|
||||||
|
@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error {
|
|||||||
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
|
||||||
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
|
||||||
|
|
||||||
srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
|
cfg := server.Config{
|
||||||
|
Meter: metricsServer.Meter,
|
||||||
|
ExposedAddress: cobraConfig.ExposedAddress,
|
||||||
|
AuthValidator: authenticator,
|
||||||
|
TLSSupport: tlsSupport,
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, err := server.NewServer(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to create relay server: %v", err)
|
log.Debugf("failed to create relay server: %v", err)
|
||||||
return fmt.Errorf("failed to create relay server: %v", err)
|
return fmt.Errorf("failed to create relay server: %v", err)
|
||||||
|
@ -8,24 +8,24 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
prefixLength = 4
|
prefixLength = 4
|
||||||
IDSize = prefixLength + sha256.Size
|
peerIDSize = prefixLength + sha256.Size
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
prefix = []byte("sha-") // 4 bytes
|
prefix = []byte("sha-") // 4 bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
type PeerID [peerIDSize]byte
|
||||||
func HashID(peerID string) ([]byte, string) {
|
|
||||||
idHash := sha256.Sum256([]byte(peerID))
|
func (p PeerID) String() string {
|
||||||
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
|
return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
|
||||||
var prefixedHash []byte
|
|
||||||
prefixedHash = append(prefixedHash, prefix...)
|
|
||||||
prefixedHash = append(prefixedHash, idHash[:]...)
|
|
||||||
return prefixedHash, idHashString
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HashIDToString converts a hash to a human-readable string
|
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
||||||
func HashIDToString(idHash []byte) string {
|
func HashID(peerID string) PeerID {
|
||||||
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
|
idHash := sha256.Sum256([]byte(peerID))
|
||||||
|
var prefixedHash [peerIDSize]byte
|
||||||
|
copy(prefixedHash[:prefixLength], prefix)
|
||||||
|
copy(prefixedHash[prefixLength:], idHash[:])
|
||||||
|
return prefixedHash
|
||||||
}
|
}
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
package messages
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHashID(t *testing.T) {
|
|
||||||
hashedID, hashedStringId := HashID("alice")
|
|
||||||
enc := HashIDToString(hashedID)
|
|
||||||
if enc != hashedStringId {
|
|
||||||
t.Errorf("expected %s, got %s", hashedStringId, enc)
|
|
||||||
}
|
|
||||||
}
|
|
@ -9,19 +9,26 @@ import (
|
|||||||
const (
|
const (
|
||||||
MaxHandshakeSize = 212
|
MaxHandshakeSize = 212
|
||||||
MaxHandshakeRespSize = 8192
|
MaxHandshakeRespSize = 8192
|
||||||
|
MaxMessageSize = 8820
|
||||||
|
|
||||||
CurrentProtocolVersion = 1
|
CurrentProtocolVersion = 1
|
||||||
|
|
||||||
MsgTypeUnknown MsgType = 0
|
MsgTypeUnknown MsgType = 0
|
||||||
// Deprecated: Use MsgTypeAuth instead.
|
// Deprecated: Use MsgTypeAuth instead.
|
||||||
MsgTypeHello MsgType = 1
|
MsgTypeHello = 1
|
||||||
// Deprecated: Use MsgTypeAuthResponse instead.
|
// Deprecated: Use MsgTypeAuthResponse instead.
|
||||||
MsgTypeHelloResponse MsgType = 2
|
MsgTypeHelloResponse = 2
|
||||||
MsgTypeTransport MsgType = 3
|
MsgTypeTransport = 3
|
||||||
MsgTypeClose MsgType = 4
|
MsgTypeClose = 4
|
||||||
MsgTypeHealthCheck MsgType = 5
|
MsgTypeHealthCheck = 5
|
||||||
MsgTypeAuth = 6
|
MsgTypeAuth = 6
|
||||||
MsgTypeAuthResponse = 7
|
MsgTypeAuthResponse = 7
|
||||||
|
|
||||||
|
// Peers state messages
|
||||||
|
MsgTypeSubscribePeerState = 8
|
||||||
|
MsgTypeUnsubscribePeerState = 9
|
||||||
|
MsgTypePeersOnline = 10
|
||||||
|
MsgTypePeersWentOffline = 11
|
||||||
|
|
||||||
// base size of the message
|
// base size of the message
|
||||||
sizeOfVersionByte = 1
|
sizeOfVersionByte = 1
|
||||||
@ -30,17 +37,17 @@ const (
|
|||||||
|
|
||||||
// auth message
|
// auth message
|
||||||
sizeOfMagicByte = 4
|
sizeOfMagicByte = 4
|
||||||
headerSizeAuth = sizeOfMagicByte + IDSize
|
headerSizeAuth = sizeOfMagicByte + peerIDSize
|
||||||
offsetMagicByte = sizeOfProtoHeader
|
offsetMagicByte = sizeOfProtoHeader
|
||||||
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||||
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||||
|
|
||||||
// hello message
|
// hello message
|
||||||
headerSizeHello = sizeOfMagicByte + IDSize
|
headerSizeHello = sizeOfMagicByte + peerIDSize
|
||||||
headerSizeHelloResp = 0
|
headerSizeHelloResp = 0
|
||||||
|
|
||||||
// transport
|
// transport
|
||||||
headerSizeTransport = IDSize
|
headerSizeTransport = peerIDSize
|
||||||
offsetTransportID = sizeOfProtoHeader
|
offsetTransportID = sizeOfProtoHeader
|
||||||
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||||
)
|
)
|
||||||
@ -72,6 +79,14 @@ func (m MsgType) String() string {
|
|||||||
return "close"
|
return "close"
|
||||||
case MsgTypeHealthCheck:
|
case MsgTypeHealthCheck:
|
||||||
return "health check"
|
return "health check"
|
||||||
|
case MsgTypeSubscribePeerState:
|
||||||
|
return "subscribe peer state"
|
||||||
|
case MsgTypeUnsubscribePeerState:
|
||||||
|
return "unsubscribe peer state"
|
||||||
|
case MsgTypePeersOnline:
|
||||||
|
return "peers online"
|
||||||
|
case MsgTypePeersWentOffline:
|
||||||
|
return "peers went offline"
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
|||||||
MsgTypeAuth,
|
MsgTypeAuth,
|
||||||
MsgTypeTransport,
|
MsgTypeTransport,
|
||||||
MsgTypeClose,
|
MsgTypeClose,
|
||||||
MsgTypeHealthCheck:
|
MsgTypeHealthCheck,
|
||||||
|
MsgTypeSubscribePeerState,
|
||||||
|
MsgTypeUnsubscribePeerState:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||||
@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
|||||||
MsgTypeAuthResponse,
|
MsgTypeAuthResponse,
|
||||||
MsgTypeTransport,
|
MsgTypeTransport,
|
||||||
MsgTypeClose,
|
MsgTypeClose,
|
||||||
MsgTypeHealthCheck:
|
MsgTypeHealthCheck,
|
||||||
|
MsgTypePeersOnline,
|
||||||
|
MsgTypePeersWentOffline:
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
default:
|
default:
|
||||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||||
@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
|||||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||||
// close the network connection without any response.
|
// close the network connection without any response.
|
||||||
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
|
|
||||||
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||||
|
|
||||||
msg = append(msg, peerID...)
|
msg = append(msg, peerID[:]...)
|
||||||
msg = append(msg, additions...)
|
msg = append(msg, additions...)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
|
|||||||
// Deprecated: Use UnmarshalAuthMsg instead.
|
// Deprecated: Use UnmarshalAuthMsg instead.
|
||||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||||
// authenticate the client with the server.
|
// authenticate the client with the server.
|
||||||
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
|
||||||
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
|
|||||||
return nil, nil, errors.New("invalid magic header")
|
return nil, nil, errors.New("invalid magic header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
|
peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
|
||||||
|
|
||||||
|
return &peerID, msg[headerSizeHello:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use MarshalAuthResponse instead.
|
// Deprecated: Use MarshalAuthResponse instead.
|
||||||
@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
|||||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||||
// close the network connection without any response.
|
// close the network connection without any response.
|
||||||
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
|
func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
return nil, fmt.Errorf("too large auth payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))
|
msg := make([]byte, headerTotalSizeAuth+len(authPayload))
|
||||||
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeAuth)
|
msg[1] = byte(MsgTypeAuth)
|
||||||
|
|
||||||
copy(msg[sizeOfProtoHeader:], magicHeader)
|
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||||
|
copy(msg[offsetAuthPeerID:], peerID[:])
|
||||||
msg = append(msg, peerID...)
|
copy(msg[headerTotalSizeAuth:], authPayload)
|
||||||
msg = append(msg, authPayload...)
|
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||||
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
|
func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
|
||||||
if len(msg) < headerTotalSizeAuth {
|
if len(msg) < headerTotalSizeAuth {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate the magic header
|
||||||
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||||
return nil, nil, errors.New("invalid magic header")
|
return nil, nil, errors.New("invalid magic header")
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
|
peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
|
||||||
|
return &peerID, msg[headerTotalSizeAuth:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalAuthResponse creates a response message to the auth.
|
// MarshalAuthResponse creates a response message to the auth.
|
||||||
@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte {
|
|||||||
// MarshalTransportMsg creates a transport message.
|
// MarshalTransportMsg creates a transport message.
|
||||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||||
// destination peer hashed ID.
|
// destination peer hashed ID.
|
||||||
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
|
func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
|
||||||
if len(peerID) != IDSize {
|
// todo validate size
|
||||||
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
|
msg := make([]byte, headerTotalSizeTransport+len(payload))
|
||||||
}
|
|
||||||
|
|
||||||
msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
|
|
||||||
msg[0] = byte(CurrentProtocolVersion)
|
msg[0] = byte(CurrentProtocolVersion)
|
||||||
msg[1] = byte(MsgTypeTransport)
|
msg[1] = byte(MsgTypeTransport)
|
||||||
copy(msg[sizeOfProtoHeader:], peerID)
|
copy(msg[sizeOfProtoHeader:], peerID[:])
|
||||||
msg = append(msg, payload...)
|
copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
|
||||||
if len(buf) < headerTotalSizeTransport {
|
if len(buf) < headerTotalSizeTransport {
|
||||||
return nil, nil, ErrInvalidMessageLength
|
return nil, nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
|
const offsetEnd = offsetTransportID + peerIDSize
|
||||||
|
var peerID PeerID
|
||||||
|
copy(peerID[:], buf[offsetTransportID:offsetEnd])
|
||||||
|
return &peerID, buf[headerTotalSizeTransport:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
func UnmarshalTransportID(buf []byte) (*PeerID, error) {
|
||||||
if len(buf) < headerTotalSizeTransport {
|
if len(buf) < headerTotalSizeTransport {
|
||||||
return nil, ErrInvalidMessageLength
|
return nil, ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
return buf[offsetTransportID:headerTotalSizeTransport], nil
|
|
||||||
|
const offsetEnd = offsetTransportID + peerIDSize
|
||||||
|
var id PeerID
|
||||||
|
copy(id[:], buf[offsetTransportID:offsetEnd])
|
||||||
|
return &id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTransportMsg updates the peerID in the transport message.
|
// UpdateTransportMsg updates the peerID in the transport message.
|
||||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||||
// need to allocate a new byte slice.
|
// need to allocate a new byte slice.
|
||||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
func UpdateTransportMsg(msg []byte, peerID PeerID) error {
|
||||||
if len(msg) < offsetTransportID+len(peerID) {
|
if len(msg) < offsetTransportID+peerIDSize {
|
||||||
return ErrInvalidMessageLength
|
return ErrInvalidMessageLength
|
||||||
}
|
}
|
||||||
copy(msg[offsetTransportID:], peerID)
|
copy(msg[offsetTransportID:], peerID[:])
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMarshalHelloMsg(t *testing.T) {
|
func TestMarshalHelloMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
msg, err := MarshalHelloMsg(peerID, nil)
|
msg, err := MarshalHelloMsg(peerID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
if string(receivedPeerID) != string(peerID) {
|
if receivedPeerID.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalAuthMsg(t *testing.T) {
|
func TestMarshalAuthMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
msg, err := MarshalAuthMsg(peerID, []byte{})
|
msg, err := MarshalAuthMsg(peerID, []byte{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
if string(receivedPeerID) != string(peerID) {
|
if receivedPeerID.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalTransportMsg(t *testing.T) {
|
func TestMarshalTransportMsg(t *testing.T) {
|
||||||
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||||
payload := []byte("payload")
|
payload := []byte("payload")
|
||||||
msg, err := MarshalTransportMsg(peerID, payload)
|
msg, err := MarshalTransportMsg(peerID, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) {
|
|||||||
t.Fatalf("failed to unmarshal transport id: %v", err)
|
t.Fatalf("failed to unmarshal transport id: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(uPeerID) != string(peerID) {
|
if uPeerID.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) {
|
|||||||
t.Fatalf("error: %v", err)
|
t.Fatalf("error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(id) != string(peerID) {
|
if id.String() != peerID.String() {
|
||||||
t.Errorf("expected %s, got %s", peerID, id)
|
t.Errorf("expected: '%s', got: '%s'", peerID, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(respPayload) != string(payload) {
|
if string(respPayload) != string(payload) {
|
||||||
|
92
relay/messages/peer_state.go
Normal file
92
relay/messages/peer_state.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
|
||||||
|
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
|
||||||
|
return unmarshalPeerIDs(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
|
||||||
|
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
|
||||||
|
return unmarshalPeerIDs(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
|
||||||
|
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
|
||||||
|
return unmarshalPeerIDs(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
|
||||||
|
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
|
||||||
|
return unmarshalPeerIDs(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
|
||||||
|
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil, fmt.Errorf("no list of peer ids provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
|
||||||
|
var messages [][]byte
|
||||||
|
|
||||||
|
for i := 0; i < len(ids); i += maxPeersPerMessage {
|
||||||
|
end := i + maxPeersPerMessage
|
||||||
|
if end > len(ids) {
|
||||||
|
end = len(ids)
|
||||||
|
}
|
||||||
|
chunk := ids[i:end]
|
||||||
|
|
||||||
|
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
|
||||||
|
buf := make([]byte, totalSize)
|
||||||
|
buf[0] = byte(CurrentProtocolVersion)
|
||||||
|
buf[1] = msgType
|
||||||
|
|
||||||
|
offset := sizeOfProtoHeader
|
||||||
|
for _, id := range chunk {
|
||||||
|
copy(buf[offset:], id[:])
|
||||||
|
offset += peerIDSize
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
|
||||||
|
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
|
||||||
|
if len(buf) < sizeOfProtoHeader {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
|
||||||
|
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
|
||||||
|
|
||||||
|
ids := make([]PeerID, numIDs)
|
||||||
|
offset := sizeOfProtoHeader
|
||||||
|
for i := 0; i < numIDs; i++ {
|
||||||
|
copy(ids[i][:], buf[offset:offset+peerIDSize])
|
||||||
|
offset += peerIDSize
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids, nil
|
||||||
|
}
|
144
relay/messages/peer_state_test.go
Normal file
144
relay/messages/peer_state_test.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testPeerCount = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to generate test PeerIDs
|
||||||
|
func generateTestPeerIDs(n int) []PeerID {
|
||||||
|
ids := make([]PeerID, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
for j := 0; j < peerIDSize; j++ {
|
||||||
|
ids[i][j] = byte(i + j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to compare slices of PeerID
|
||||||
|
func peerIDEqual(a, b []PeerID) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if !bytes.Equal(a[i][:], b[i][:]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
|
||||||
|
ids := generateTestPeerIDs(testPeerCount)
|
||||||
|
|
||||||
|
msgs, err := MarshalSubPeerStateMsg(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allIDs []PeerID
|
||||||
|
for _, msg := range msgs {
|
||||||
|
decoded, err := UnmarshalSubPeerStateMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
allIDs = append(allIDs, decoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIDEqual(ids, allIDs) {
|
||||||
|
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
|
||||||
|
_, err := MarshalSubPeerStateMsg([]PeerID{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for empty input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
|
||||||
|
// Too short
|
||||||
|
_, err := UnmarshalSubPeerStateMsg([]byte{1})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for short input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Misaligned length
|
||||||
|
buf := make([]byte, sizeOfProtoHeader+1)
|
||||||
|
_, err = UnmarshalSubPeerStateMsg(buf)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for misaligned input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
|
||||||
|
ids := generateTestPeerIDs(testPeerCount)
|
||||||
|
|
||||||
|
msgs, err := MarshalPeersOnline(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allIDs []PeerID
|
||||||
|
for _, msg := range msgs {
|
||||||
|
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
allIDs = append(allIDs, decoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIDEqual(ids, allIDs) {
|
||||||
|
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
|
||||||
|
_, err := MarshalPeersOnline([]PeerID{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for empty input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
|
||||||
|
_, err := UnmarshalPeersOnlineMsg([]byte{1})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for short input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
|
||||||
|
ids := generateTestPeerIDs(testPeerCount)
|
||||||
|
|
||||||
|
msgs, err := MarshalPeersWentOffline(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allIDs []PeerID
|
||||||
|
for _, msg := range msgs {
|
||||||
|
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
|
||||||
|
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
allIDs = append(allIDs, decoded...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIDEqual(ids, allIDs) {
|
||||||
|
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
|
||||||
|
_, err := MarshalPeersWentOffline([]PeerID{})
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error for empty input")
|
||||||
|
}
|
||||||
|
}
|
@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/messages/address"
|
"github.com/netbirdio/netbird/relay/messages/address"
|
||||||
@ -14,6 +13,12 @@ import (
|
|||||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Validator interface {
|
||||||
|
Validate(any) error
|
||||||
|
// Deprecated: Use Validate instead.
|
||||||
|
ValidateHelloMsgType(any) error
|
||||||
|
}
|
||||||
|
|
||||||
// preparedMsg contains the marshalled success response messages
|
// preparedMsg contains the marshalled success response messages
|
||||||
type preparedMsg struct {
|
type preparedMsg struct {
|
||||||
responseHelloMsg []byte
|
responseHelloMsg []byte
|
||||||
@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) {
|
|||||||
|
|
||||||
type handshake struct {
|
type handshake struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
validator auth.Validator
|
validator Validator
|
||||||
preparedMsg *preparedMsg
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
handshakeMethodAuth bool
|
handshakeMethodAuth bool
|
||||||
peerID string
|
peerID *messages.PeerID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handshakeReceive() ([]byte, error) {
|
func (h *handshake) handshakeReceive() (*messages.PeerID, error) {
|
||||||
buf := make([]byte, messages.MaxHandshakeSize)
|
buf := make([]byte, messages.MaxHandshakeSize)
|
||||||
n, err := h.conn.Read(buf)
|
n, err := h.conn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var peerID *messages.PeerID
|
||||||
bytePeerID []byte
|
|
||||||
peerID string
|
|
||||||
)
|
|
||||||
switch msgType {
|
switch msgType {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
case messages.MsgTypeHello:
|
case messages.MsgTypeHello:
|
||||||
bytePeerID, peerID, err = h.handleHelloMsg(buf)
|
peerID, err = h.handleHelloMsg(buf)
|
||||||
case messages.MsgTypeAuth:
|
case messages.MsgTypeAuth:
|
||||||
h.handshakeMethodAuth = true
|
h.handshakeMethodAuth = true
|
||||||
bytePeerID, peerID, err = h.handleAuthMsg(buf)
|
peerID, err = h.handleAuthMsg(buf)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
h.peerID = peerID
|
h.peerID = peerID
|
||||||
return bytePeerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handshakeResponse() error {
|
func (h *handshake) handshakeResponse() error {
|
||||||
@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) {
|
func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) {
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
peerID, authData, err := messages.UnmarshalHelloMsg(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr())
|
||||||
|
|
||||||
authMsg, err := authmsg.UnmarshalMsg(authData)
|
authMsg, err := authmsg.UnmarshalMsg(authData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unmarshal auth message: %w", err)
|
return nil, fmt.Errorf("unmarshal auth message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
|
||||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawPeerID, peerID, nil
|
return peerID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) {
|
func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) {
|
||||||
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unmarshal hello message: %w", err)
|
return nil, fmt.Errorf("unmarshal hello message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerID := messages.HashIDToString(rawPeerID)
|
|
||||||
|
|
||||||
if err := h.validator.Validate(authPayload); err != nil {
|
if err := h.validator.Validate(authPayload); err != nil {
|
||||||
return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err)
|
return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rawPeerID, peerID, nil
|
return rawPeerID, nil
|
||||||
}
|
}
|
||||||
|
@ -12,43 +12,50 @@ import (
|
|||||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||||
"github.com/netbirdio/netbird/relay/messages"
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bufferSize = 8820
|
bufferSize = messages.MaxMessageSize
|
||||||
|
|
||||||
errCloseConn = "failed to close connection to peer: %s"
|
errCloseConn = "failed to close connection to peer: %s"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Peer represents a peer connection
|
// Peer represents a peer connection
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
metrics *metrics.Metrics
|
metrics *metrics.Metrics
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
idS string
|
id messages.PeerID
|
||||||
idB []byte
|
conn net.Conn
|
||||||
conn net.Conn
|
connMu sync.RWMutex
|
||||||
connMu sync.RWMutex
|
store *store.Store
|
||||||
store *Store
|
notifier *store.PeerNotifier
|
||||||
|
|
||||||
|
peersListener *store.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeer creates a new Peer instance and prepare custom logging
|
// NewPeer creates a new Peer instance and prepare custom logging
|
||||||
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
|
func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer {
|
||||||
stringID := messages.HashIDToString(id)
|
p := &Peer{
|
||||||
return &Peer{
|
metrics: metrics,
|
||||||
metrics: metrics,
|
log: log.WithField("peer_id", id.String()),
|
||||||
log: log.WithField("peer_id", stringID),
|
id: id,
|
||||||
idS: stringID,
|
conn: conn,
|
||||||
idB: id,
|
store: store,
|
||||||
conn: conn,
|
notifier: notifier,
|
||||||
store: store,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Work reads data from the connection
|
// Work reads data from the connection
|
||||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||||
// the message accordingly.
|
// the message accordingly.
|
||||||
func (p *Peer) Work() {
|
func (p *Peer) Work() {
|
||||||
|
p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline)
|
||||||
defer func() {
|
defer func() {
|
||||||
|
p.notifier.RemoveListener(p.peersListener)
|
||||||
|
|
||||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
p.log.Errorf(errCloseConn, err)
|
p.log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
@ -94,6 +101,10 @@ func (p *Peer) Work() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Peer) ID() messages.PeerID {
|
||||||
|
return p.id
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case messages.MsgTypeHealthCheck:
|
case messages.MsgTypeHealthCheck:
|
||||||
@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
|||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
log.Errorf(errCloseConn, err)
|
log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
|
case messages.MsgTypeSubscribePeerState:
|
||||||
|
p.handleSubscribePeerState(msg)
|
||||||
|
case messages.MsgTypeUnsubscribePeerState:
|
||||||
|
p.handleUnsubscribePeerState(msg)
|
||||||
default:
|
default:
|
||||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||||
}
|
}
|
||||||
@ -145,7 +160,7 @@ func (p *Peer) Close() {
|
|||||||
|
|
||||||
// String returns the peer ID
|
// String returns the peer ID
|
||||||
func (p *Peer) String() string {
|
func (p *Peer) String() string {
|
||||||
return p.idS
|
return p.id.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
|
||||||
@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stringPeerID := messages.HashIDToString(peerID)
|
item, ok := p.store.Peer(*peerID)
|
||||||
dp, ok := p.store.Peer(stringPeerID)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
p.log.Debugf("peer not found: %s", stringPeerID)
|
p.log.Debugf("peer not found: %s", peerID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
dp := item.(*Peer)
|
||||||
|
|
||||||
err = messages.UpdateTransportMsg(msg, p.idB)
|
err = messages.UpdateTransportMsg(msg, p.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.log.Errorf("failed to update transport message: %s", err)
|
p.log.Errorf("failed to update transport message: %s", err)
|
||||||
return
|
return
|
||||||
@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) {
|
|||||||
}
|
}
|
||||||
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Peer) handleSubscribePeerState(msg []byte) {
|
||||||
|
peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to unmarshal open connection message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
|
||||||
|
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
|
||||||
|
if len(onlinePeers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.log.Debugf("response with %d online peers", len(onlinePeers))
|
||||||
|
p.sendPeersOnline(onlinePeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) handleUnsubscribePeerState(msg []byte) {
|
||||||
|
peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to unmarshal open connection message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.peersListener.RemoveInterestedPeer(peerIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
|
||||||
|
msgs, err := messages.MarshalPeersOnline(peers)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, msg := range msgs {
|
||||||
|
if _, err := p.Write(msg); err != nil {
|
||||||
|
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
|
||||||
|
msgs, err := messages.MarshalPeersWentOffline(peers)
|
||||||
|
if err != nil {
|
||||||
|
p.log.Errorf("failed to marshal peer location message: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, msg := range msgs {
|
||||||
|
if _, err := p.Write(msg); err != nil {
|
||||||
|
p.log.Errorf("failed to write %d. peers offline message: %s", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -4,26 +4,55 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
|
||||||
//nolint:staticcheck
|
//nolint:staticcheck
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
"github.com/netbirdio/netbird/relay/metrics"
|
||||||
|
"github.com/netbirdio/netbird/relay/server/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Meter metric.Meter
|
||||||
|
ExposedAddress string
|
||||||
|
TLSSupport bool
|
||||||
|
AuthValidator Validator
|
||||||
|
|
||||||
|
instanceURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) validate() error {
|
||||||
|
if c.Meter == nil {
|
||||||
|
c.Meter = otel.Meter("")
|
||||||
|
}
|
||||||
|
if c.ExposedAddress == "" {
|
||||||
|
return fmt.Errorf("exposed address is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid url: %v", err)
|
||||||
|
}
|
||||||
|
c.instanceURL = instanceURL
|
||||||
|
|
||||||
|
if c.AuthValidator == nil {
|
||||||
|
return fmt.Errorf("auth validator is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Relay represents the relay server
|
// Relay represents the relay server
|
||||||
type Relay struct {
|
type Relay struct {
|
||||||
metrics *metrics.Metrics
|
metrics *metrics.Metrics
|
||||||
metricsCancel context.CancelFunc
|
metricsCancel context.CancelFunc
|
||||||
validator auth.Validator
|
validator Validator
|
||||||
|
|
||||||
store *Store
|
store *store.Store
|
||||||
|
notifier *store.PeerNotifier
|
||||||
instanceURL string
|
instanceURL string
|
||||||
preparedMsg *preparedMsg
|
preparedMsg *preparedMsg
|
||||||
|
|
||||||
@ -31,40 +60,40 @@ type Relay struct {
|
|||||||
closeMu sync.RWMutex
|
closeMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRelay creates a new Relay instance
|
// NewRelay creates and returns a new Relay instance.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
|
//
|
||||||
// metrics for the relay server.
|
// config: A Config struct that holds the configuration needed to initialize the relay server.
|
||||||
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
|
// - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used.
|
||||||
// address as the relay server's instance URL.
|
// - ExposedAddress: The external address clients use to reach this relay. Required.
|
||||||
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
|
// - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL.
|
||||||
// instance URL depends on this value.
|
// - AuthValidator: A Validator implementation used to authenticate peers. Required.
|
||||||
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
|
|
||||||
// peers.
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
|
//
|
||||||
// Otherwise, the error contains the details of what went wrong.
|
// A pointer to a Relay instance and an error. If initialization is successful, the error will be nil;
|
||||||
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
|
// otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration).
|
||||||
|
func NewRelay(config Config) (*Relay, error) {
|
||||||
|
if err := config.validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ctx, metricsCancel := context.WithCancel(context.Background())
|
ctx, metricsCancel := context.WithCancel(context.Background())
|
||||||
m, err := metrics.NewMetrics(ctx, meter)
|
m, err := metrics.NewMetrics(ctx, config.Meter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
metricsCancel()
|
metricsCancel()
|
||||||
return nil, fmt.Errorf("creating app metrics: %v", err)
|
return nil, fmt.Errorf("creating app metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerStore := store.NewStore()
|
||||||
r := &Relay{
|
r := &Relay{
|
||||||
metrics: m,
|
metrics: m,
|
||||||
metricsCancel: metricsCancel,
|
metricsCancel: metricsCancel,
|
||||||
validator: validator,
|
validator: config.AuthValidator,
|
||||||
store: NewStore(),
|
instanceURL: config.instanceURL,
|
||||||
}
|
store: peerStore,
|
||||||
|
notifier: store.NewPeerNotifier(peerStore),
|
||||||
r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
|
|
||||||
if err != nil {
|
|
||||||
metricsCancel()
|
|
||||||
return nil, fmt.Errorf("get instance URL: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
|
||||||
@ -76,32 +105,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida
|
|||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
|
|
||||||
// provided address according to TLS definition and parses the address before returning it
|
|
||||||
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
|
||||||
addr := exposedAddress
|
|
||||||
split := strings.Split(exposedAddress, "://")
|
|
||||||
switch {
|
|
||||||
case len(split) == 1 && tlsSupported:
|
|
||||||
addr = "rels://" + exposedAddress
|
|
||||||
case len(split) == 1 && !tlsSupported:
|
|
||||||
addr = "rel://" + exposedAddress
|
|
||||||
case len(split) > 2:
|
|
||||||
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedURL, err := url.ParseRequestURI(addr)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("invalid exposed address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
|
|
||||||
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsedURL.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accept start to handle a new peer connection
|
// Accept start to handle a new peer connection
|
||||||
func (r *Relay) Accept(conn net.Conn) {
|
func (r *Relay) Accept(conn net.Conn) {
|
||||||
acceptTime := time.Now()
|
acceptTime := time.Now()
|
||||||
@ -125,14 +128,17 @@ func (r *Relay) Accept(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peer := NewPeer(r.metrics, peerID, conn, r.store)
|
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
|
||||||
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
|
||||||
storeTime := time.Now()
|
storeTime := time.Now()
|
||||||
r.store.AddPeer(peer)
|
r.store.AddPeer(peer)
|
||||||
|
r.notifier.PeerCameOnline(peer.ID())
|
||||||
|
|
||||||
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
|
||||||
r.metrics.PeerConnected(peer.String())
|
r.metrics.PeerConnected(peer.String())
|
||||||
go func() {
|
go func() {
|
||||||
peer.Work()
|
peer.Work()
|
||||||
|
r.notifier.PeerWentOffline(peer.ID())
|
||||||
r.store.DeletePeer(peer)
|
r.store.DeletePeer(peer)
|
||||||
peer.log.Debugf("relay connection closed")
|
peer.log.Debugf("relay connection closed")
|
||||||
r.metrics.PeerDisconnected(peer.String())
|
r.metrics.PeerDisconnected(peer.String())
|
||||||
@ -154,12 +160,12 @@ func (r *Relay) Shutdown(ctx context.Context) {
|
|||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
peers := r.store.Peers()
|
peers := r.store.Peers()
|
||||||
for _, peer := range peers {
|
for _, v := range peers {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(p *Peer) {
|
go func(p *Peer) {
|
||||||
p.CloseGracefully(ctx)
|
p.CloseGracefully(ctx)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}(peer)
|
}(v.(*Peer))
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
r.metricsCancel()
|
r.metricsCancel()
|
||||||
|
@ -6,15 +6,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"go.opentelemetry.io/otel/metric"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/relay/auth"
|
|
||||||
"github.com/netbirdio/netbird/relay/server/listener"
|
"github.com/netbirdio/netbird/relay/server/listener"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
||||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ListenerConfig is the configuration for the listener.
|
// ListenerConfig is the configuration for the listener.
|
||||||
@ -33,13 +30,22 @@ type Server struct {
|
|||||||
listeners []listener.Listener
|
listeners []listener.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new relay server instance.
|
// NewServer creates and returns a new relay server instance.
|
||||||
// meter: the OpenTelemetry meter
|
//
|
||||||
// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
|
// Parameters:
|
||||||
// tlsSupport: if true, the server will support TLS
|
//
|
||||||
// authValidator: the auth validator to use for the server
|
// config: A Config struct containing the necessary configuration:
|
||||||
func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
|
// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used.
|
||||||
relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
|
// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required.
|
||||||
|
// - TLSSupport: A boolean indicating whether TLS is enabled for the server.
|
||||||
|
// - AuthValidator: A Validator used to authenticate peers. Required.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
//
|
||||||
|
// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds,
|
||||||
|
// the returned error will be nil. Otherwise, the error will describe the problem.
|
||||||
|
func NewServer(config Config) (*Server, error) {
|
||||||
|
relay, err := NewRelay(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
121
relay/server/store/listener.go
Normal file
121
relay/server/store/listener.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Listener struct {
|
||||||
|
store *Store
|
||||||
|
|
||||||
|
onlineChan chan messages.PeerID
|
||||||
|
offlineChan chan messages.PeerID
|
||||||
|
interestedPeersForOffline map[messages.PeerID]struct{}
|
||||||
|
interestedPeersForOnline map[messages.PeerID]struct{}
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
listenerCtx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func newListener(store *Store) *Listener {
|
||||||
|
l := &Listener{
|
||||||
|
store: store,
|
||||||
|
|
||||||
|
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
||||||
|
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
|
||||||
|
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
|
||||||
|
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
|
||||||
|
availablePeers := make([]messages.PeerID, 0)
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
for _, id := range peerIDs {
|
||||||
|
l.interestedPeersForOnline[id] = struct{}{}
|
||||||
|
l.interestedPeersForOffline[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collect online peers to response back to the caller
|
||||||
|
for _, id := range peerIDs {
|
||||||
|
_, ok := l.store.Peer(id)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
availablePeers = append(availablePeers, id)
|
||||||
|
}
|
||||||
|
return availablePeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
for _, id := range peerIDs {
|
||||||
|
delete(l.interestedPeersForOffline, id)
|
||||||
|
delete(l.interestedPeersForOnline, id)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) {
|
||||||
|
l.listenerCtx = ctx
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case pID := <-l.onlineChan:
|
||||||
|
peers := make([]messages.PeerID, 0)
|
||||||
|
peers = append(peers, pID)
|
||||||
|
|
||||||
|
for len(l.onlineChan) > 0 {
|
||||||
|
pID = <-l.onlineChan
|
||||||
|
peers = append(peers, pID)
|
||||||
|
}
|
||||||
|
|
||||||
|
onPeersComeOnline(peers)
|
||||||
|
case pID := <-l.offlineChan:
|
||||||
|
peers := make([]messages.PeerID, 0)
|
||||||
|
peers = append(peers, pID)
|
||||||
|
|
||||||
|
for len(l.offlineChan) > 0 {
|
||||||
|
pID = <-l.offlineChan
|
||||||
|
peers = append(peers, pID)
|
||||||
|
}
|
||||||
|
|
||||||
|
onPeersWentOffline(peers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) peerWentOffline(peerID messages.PeerID) {
|
||||||
|
l.mu.RLock()
|
||||||
|
defer l.mu.RUnlock()
|
||||||
|
|
||||||
|
if _, ok := l.interestedPeersForOffline[peerID]; ok {
|
||||||
|
select {
|
||||||
|
case l.offlineChan <- peerID:
|
||||||
|
case <-l.listenerCtx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Listener) peerComeOnline(peerID messages.PeerID) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := l.interestedPeersForOnline[peerID]; ok {
|
||||||
|
select {
|
||||||
|
case l.onlineChan <- peerID:
|
||||||
|
case <-l.listenerCtx.Done():
|
||||||
|
}
|
||||||
|
delete(l.interestedPeersForOnline, peerID)
|
||||||
|
}
|
||||||
|
}
|
64
relay/server/store/notifier.go
Normal file
64
relay/server/store/notifier.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerNotifier struct {
|
||||||
|
store *Store
|
||||||
|
|
||||||
|
listeners map[*Listener]context.CancelFunc
|
||||||
|
listenersMutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerNotifier(store *Store) *PeerNotifier {
|
||||||
|
pn := &PeerNotifier{
|
||||||
|
store: store,
|
||||||
|
listeners: make(map[*Listener]context.CancelFunc),
|
||||||
|
}
|
||||||
|
return pn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
listener := newListener(pn.store)
|
||||||
|
go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline)
|
||||||
|
|
||||||
|
pn.listenersMutex.Lock()
|
||||||
|
pn.listeners[listener] = cancel
|
||||||
|
pn.listenersMutex.Unlock()
|
||||||
|
return listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pn *PeerNotifier) RemoveListener(listener *Listener) {
|
||||||
|
pn.listenersMutex.Lock()
|
||||||
|
defer pn.listenersMutex.Unlock()
|
||||||
|
|
||||||
|
cancel, ok := pn.listeners[listener]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
delete(pn.listeners, listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) {
|
||||||
|
pn.listenersMutex.RLock()
|
||||||
|
defer pn.listenersMutex.RUnlock()
|
||||||
|
|
||||||
|
for listener := range pn.listeners {
|
||||||
|
listener.peerWentOffline(peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) {
|
||||||
|
pn.listenersMutex.RLock()
|
||||||
|
defer pn.listenersMutex.RUnlock()
|
||||||
|
|
||||||
|
for listener := range pn.listeners {
|
||||||
|
listener.peerComeOnline(peerID)
|
||||||
|
}
|
||||||
|
}
|
@ -1,41 +1,48 @@
|
|||||||
package server
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type IPeer interface {
|
||||||
|
Close()
|
||||||
|
ID() messages.PeerID
|
||||||
|
}
|
||||||
|
|
||||||
// Store is a thread-safe store of peers
|
// Store is a thread-safe store of peers
|
||||||
// It is used to store the peers that are connected to the relay server
|
// It is used to store the peers that are connected to the relay server
|
||||||
type Store struct {
|
type Store struct {
|
||||||
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
|
peers map[messages.PeerID]IPeer
|
||||||
peersLock sync.RWMutex
|
peersLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStore creates a new Store instance
|
// NewStore creates a new Store instance
|
||||||
func NewStore() *Store {
|
func NewStore() *Store {
|
||||||
return &Store{
|
return &Store{
|
||||||
peers: make(map[string]*Peer),
|
peers: make(map[messages.PeerID]IPeer),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer adds a peer to the store
|
// AddPeer adds a peer to the store
|
||||||
func (s *Store) AddPeer(peer *Peer) {
|
func (s *Store) AddPeer(peer IPeer) {
|
||||||
s.peersLock.Lock()
|
s.peersLock.Lock()
|
||||||
defer s.peersLock.Unlock()
|
defer s.peersLock.Unlock()
|
||||||
odlPeer, ok := s.peers[peer.String()]
|
odlPeer, ok := s.peers[peer.ID()]
|
||||||
if ok {
|
if ok {
|
||||||
odlPeer.Close()
|
odlPeer.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
s.peers[peer.String()] = peer
|
s.peers[peer.ID()] = peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeer deletes a peer from the store
|
// DeletePeer deletes a peer from the store
|
||||||
func (s *Store) DeletePeer(peer *Peer) {
|
func (s *Store) DeletePeer(peer IPeer) {
|
||||||
s.peersLock.Lock()
|
s.peersLock.Lock()
|
||||||
defer s.peersLock.Unlock()
|
defer s.peersLock.Unlock()
|
||||||
|
|
||||||
dp, ok := s.peers[peer.String()]
|
dp, ok := s.peers[peer.ID()]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -43,11 +50,11 @@ func (s *Store) DeletePeer(peer *Peer) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(s.peers, peer.String())
|
delete(s.peers, peer.ID())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer returns a peer by its ID
|
// Peer returns a peer by its ID
|
||||||
func (s *Store) Peer(id string) (*Peer, bool) {
|
func (s *Store) Peer(id messages.PeerID) (IPeer, bool) {
|
||||||
s.peersLock.RLock()
|
s.peersLock.RLock()
|
||||||
defer s.peersLock.RUnlock()
|
defer s.peersLock.RUnlock()
|
||||||
|
|
||||||
@ -56,11 +63,11 @@ func (s *Store) Peer(id string) (*Peer, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Peers returns all the peers in the store
|
// Peers returns all the peers in the store
|
||||||
func (s *Store) Peers() []*Peer {
|
func (s *Store) Peers() []IPeer {
|
||||||
s.peersLock.RLock()
|
s.peersLock.RLock()
|
||||||
defer s.peersLock.RUnlock()
|
defer s.peersLock.RUnlock()
|
||||||
|
|
||||||
peers := make([]*Peer, 0, len(s.peers))
|
peers := make([]IPeer, 0, len(s.peers))
|
||||||
for _, p := range s.peers {
|
for _, p := range s.peers {
|
||||||
peers = append(peers, p)
|
peers = append(peers, p)
|
||||||
}
|
}
|
49
relay/server/store/store_test.go
Normal file
49
relay/server/store/store_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/relay/messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocPeer struct {
|
||||||
|
id messages.PeerID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocPeer) Close() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocPeer) ID() messages.PeerID {
|
||||||
|
return m.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_DeletePeer(t *testing.T) {
|
||||||
|
s := NewStore()
|
||||||
|
|
||||||
|
pID := messages.HashID("peer_one")
|
||||||
|
p := &MocPeer{id: pID}
|
||||||
|
s.AddPeer(p)
|
||||||
|
s.DeletePeer(p)
|
||||||
|
if _, ok := s.Peer(pID); ok {
|
||||||
|
t.Errorf("peer was not deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
||||||
|
s := NewStore()
|
||||||
|
|
||||||
|
pID1 := messages.HashID("peer_one")
|
||||||
|
pID2 := messages.HashID("peer_one")
|
||||||
|
|
||||||
|
p1 := &MocPeer{id: pID1}
|
||||||
|
p2 := &MocPeer{id: pID2}
|
||||||
|
|
||||||
|
s.AddPeer(p1)
|
||||||
|
s.AddPeer(p2)
|
||||||
|
s.DeletePeer(p1)
|
||||||
|
|
||||||
|
if _, ok := s.Peer(pID2); !ok {
|
||||||
|
t.Errorf("second peer was deleted")
|
||||||
|
}
|
||||||
|
}
|
@ -1,85 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/metrics"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockConn struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) Read(b []byte) (n int, err error) {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) Write(b []byte) (n int, err error) {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) LocalAddr() net.Addr {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) RemoteAddr() net.Addr {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) SetDeadline(t time.Time) error {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) SetReadDeadline(t time.Time) error {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStore_DeletePeer(t *testing.T) {
|
|
||||||
s := NewStore()
|
|
||||||
|
|
||||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
|
||||||
|
|
||||||
p := NewPeer(m, []byte("peer_one"), nil, nil)
|
|
||||||
s.AddPeer(p)
|
|
||||||
s.DeletePeer(p)
|
|
||||||
if _, ok := s.Peer(p.String()); ok {
|
|
||||||
t.Errorf("peer was not deleted")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStore_DeleteDeprecatedPeer(t *testing.T) {
|
|
||||||
s := NewStore()
|
|
||||||
|
|
||||||
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
|
|
||||||
|
|
||||||
conn := &mockConn{}
|
|
||||||
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
|
|
||||||
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
|
|
||||||
|
|
||||||
s.AddPeer(p1)
|
|
||||||
s.AddPeer(p2)
|
|
||||||
s.DeletePeer(p1)
|
|
||||||
|
|
||||||
if _, ok := s.Peer(p2.String()); !ok {
|
|
||||||
t.Errorf("second peer was deleted")
|
|
||||||
}
|
|
||||||
}
|
|
33
relay/server/url.go
Normal file
33
relay/server/url.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getInstanceURL checks if user supplied a URL scheme otherwise adds to the
|
||||||
|
// provided address according to TLS definition and parses the address before returning it
|
||||||
|
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
|
||||||
|
addr := exposedAddress
|
||||||
|
split := strings.Split(exposedAddress, "://")
|
||||||
|
switch {
|
||||||
|
case len(split) == 1 && tlsSupported:
|
||||||
|
addr = "rels://" + exposedAddress
|
||||||
|
case len(split) == 1 && !tlsSupported:
|
||||||
|
addr = "rel://" + exposedAddress
|
||||||
|
case len(split) > 2:
|
||||||
|
return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, err := url.ParseRequestURI(addr)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid exposed address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
|
||||||
|
return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedURL.String(), nil
|
||||||
|
}
|
@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/turn/v3"
|
"github.com/pion/turn/v3"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/relay/auth/allow"
|
"github.com/netbirdio/netbird/relay/auth/allow"
|
||||||
"github.com/netbirdio/netbird/relay/auth/hmac"
|
"github.com/netbirdio/netbird/relay/auth/hmac"
|
||||||
@ -22,7 +21,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
av = &allow.Auth{}
|
|
||||||
hmacTokenStore = &hmac.TokenStore{}
|
hmacTokenStore = &hmac.TokenStore{}
|
||||||
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
|
pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
|
||||||
dataSize = 1024 * 1024 * 10
|
dataSize = 1024 * 1024 * 10
|
||||||
@ -70,8 +68,12 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
|
|||||||
port := 35000 + peerPairs
|
port := 35000 + peerPairs
|
||||||
serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
|
serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
|
serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
|
||||||
|
serverCfg := server.Config{
|
||||||
srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av)
|
ExposedAddress: serverConnURL,
|
||||||
|
TLSSupport: false,
|
||||||
|
AuthValidator: &allow.Auth{},
|
||||||
|
}
|
||||||
|
srv, err := server.NewServer(serverCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create server: %s", err)
|
t.Fatalf("failed to create server: %s", err)
|
||||||
}
|
}
|
||||||
@ -98,8 +100,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
|
|||||||
|
|
||||||
clientsSender := make([]*client.Client, peerPairs)
|
clientsSender := make([]*client.Client, peerPairs)
|
||||||
for i := 0; i < cap(clientsSender); i++ {
|
for i := 0; i < cap(clientsSender); i++ {
|
||||||
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
|
c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
|
||||||
err := c.Connect()
|
err := c.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -108,8 +110,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
|
|||||||
|
|
||||||
clientsReceiver := make([]*client.Client, peerPairs)
|
clientsReceiver := make([]*client.Client, peerPairs)
|
||||||
for i := 0; i < cap(clientsReceiver); i++ {
|
for i := 0; i < cap(clientsReceiver); i++ {
|
||||||
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
|
c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
|
||||||
err := c.Connect()
|
err := c.Connect(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to server: %s", err)
|
t.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -119,13 +121,13 @@ func transfer(t *testing.T, testData []byte, peerPairs int) {
|
|||||||
connsSender := make([]net.Conn, 0, peerPairs)
|
connsSender := make([]net.Conn, 0, peerPairs)
|
||||||
connsReceiver := make([]net.Conn, 0, peerPairs)
|
connsReceiver := make([]net.Conn, 0, peerPairs)
|
||||||
for i := 0; i < len(clientsSender); i++ {
|
for i := 0; i < len(clientsSender); i++ {
|
||||||
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
|
conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
connsSender = append(connsSender, conn)
|
connsSender = append(connsSender, conn)
|
||||||
|
|
||||||
conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
|
conn, err = clientsReceiver[i].OpenConn(ctx, "sender-"+fmt.Sprint(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to bind channel: %s", err)
|
t.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -70,8 +70,8 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
clientsSender := make([]*client.Client, peerPairs)
|
clientsSender := make([]*client.Client, peerPairs)
|
||||||
for i := 0; i < cap(clientsSender); i++ {
|
for i := 0; i < cap(clientsSender); i++ {
|
||||||
c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
|
c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
|
||||||
if err := c.Connect(); err != nil {
|
if err := c.Connect(ctx); err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
clientsSender[i] = c
|
clientsSender[i] = c
|
||||||
@ -79,7 +79,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
|
|||||||
|
|
||||||
connsSender := make([]net.Conn, 0, peerPairs)
|
connsSender := make([]net.Conn, 0, peerPairs)
|
||||||
for i := 0; i < len(clientsSender); i++ {
|
for i := 0; i < len(clientsSender); i++ {
|
||||||
conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
|
conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to bind channel: %s", err)
|
log.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
@ -156,8 +156,8 @@ func runReader(conn net.Conn) time.Duration {
|
|||||||
func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
|
func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
|
||||||
clientsReceiver := make([]*client.Client, peerPairs)
|
clientsReceiver := make([]*client.Client, peerPairs)
|
||||||
for i := 0; i < cap(clientsReceiver); i++ {
|
for i := 0; i < cap(clientsReceiver); i++ {
|
||||||
c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
|
c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
|
||||||
err := c.Connect()
|
err := c.Connect(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to connect to server: %s", err)
|
log.Fatalf("failed to connect to server: %s", err)
|
||||||
}
|
}
|
||||||
@ -166,7 +166,7 @@ func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
|
|||||||
|
|
||||||
connsReceiver := make([]net.Conn, 0, peerPairs)
|
connsReceiver := make([]net.Conn, 0, peerPairs)
|
||||||
for i := 0; i < len(clientsReceiver); i++ {
|
for i := 0; i < len(clientsReceiver); i++ {
|
||||||
conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
|
conn, err := clientsReceiver[i].OpenConn(context.Background(), "sender-"+fmt.Sprint(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to bind channel: %s", err)
|
log.Fatalf("failed to bind channel: %s", err)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user