diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 614787e17..179ac0b75 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) type ProxyBind struct { @@ -28,6 +29,17 @@ type ProxyBind struct { pausedMu sync.Mutex paused bool isStarted bool + + closeListener *listener.CloseListener +} + +func NewProxyBind(bind *bind.ICEBind) *ProxyBind { + p := &ProxyBind{ + Bind: bind, + closeListener: listener.NewCloseListener(), + } + + return p } // AddTurnConn adds a new connection to the bind. @@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr { } } +func (p *ProxyBind) SetDisconnectListener(disconnected func()) { + p.closeListener.SetCloseListener(disconnected) +} + func (p *ProxyBind) Work() { if p.remoteConn == nil { return @@ -96,6 +112,9 @@ func (p *ProxyBind) close() error { if p.closed { return nil } + + p.closeListener.SetCloseListener(nil) + p.closed = true p.cancel() @@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { if ctx.Err() != nil { return } + p.closeListener.Notify() log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 54cab4e1b..dbf9128a8 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -11,6 +11,8 @@ import ( "sync" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -26,6 +28,15 @@ type ProxyWrapper struct { pausedMu sync.Mutex paused bool isStarted bool + + closeListener *listener.CloseListener +} + +func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { + return &ProxyWrapper{ + WgeBPFProxy: WgeBPFProxy, + closeListener: listener.NewCloseListener(), + } } func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { @@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { return p.wgEndpointAddr } +func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { + p.closeListener.SetCloseListener(disconnected) +} + func (p *ProxyWrapper) Work() { if p.remoteConn == nil { return @@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error { e.cancel() + e.closeListener.SetCloseListener(nil) + if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { return fmt.Errorf("failed to close remote conn: %w", err) } @@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err if ctx.Err() != nil { return 0, ctx.Err() } + p.closeListener.Notify() if !errors.Is(err, io.EOF) { log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 3ad7dc59d..e62cd97be 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy { return udpProxy.NewWGUDPProxy(w.wgPort) } - return &ebpf.ProxyWrapper{ - WgeBPFProxy: w.ebpfProxy, - } + return ebpf.NewProxyWrapper(w.ebpfProxy) + } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index e2d479331..141b4c1f9 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { } func (w *USPFactory) GetProxy() Proxy { - return &proxyBind.ProxyBind{ - Bind: w.bind, - } + return proxyBind.NewProxyBind(w.bind) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/listener/listener.go b/client/iface/wgproxy/listener/listener.go new file mode 100644 index 000000000..bfd651548 --- /dev/null +++ b/client/iface/wgproxy/listener/listener.go @@ -0,0 +1,19 @@ +package listener + +type CloseListener struct { + listener func() +} + +func NewCloseListener() *CloseListener { + return &CloseListener{} +} + +func (c *CloseListener) SetCloseListener(listener func()) { + c.listener = listener +} + +func (c *CloseListener) Notify() { + if c.listener != nil { + c.listener() + } +} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index 243aa2bd2..c2879877e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -12,4 +12,5 @@ type Proxy interface { Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. CloseConn() error + SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 64b617621..2165b8aba 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { t.Errorf("failed to free ebpf proxy: %s", err) } }() - proxyWrapper := &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - } + proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) tests = append(tests, struct { name string diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index ba0004b8a..df45d8ca5 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" cerrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // WGUDPProxy proxies @@ -28,6 +29,8 @@ type WGUDPProxy struct { pausedMu sync.Mutex paused bool isStarted bool + + closeListener *listener.CloseListener } // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation @@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) p := &WGUDPProxy{ localWGListenPort: wgPort, + closeListener: listener.NewCloseListener(), } return p } @@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr { return endpointUdpAddr } +func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) { + p.closeListener.SetCloseListener(disconnected) +} + // Work starts the proxy or resumes it if it was paused func (p *WGUDPProxy) Work() { if p.remoteConn == nil { @@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error { if p.closed { return nil } + + p.closeListener.SetCloseListener(nil) p.closed = true p.cancel() @@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { if ctx.Err() != nil { return } + p.closeListener.Notify() log.Debugf("failed to read from wg interface conn: %s", err) return } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 1f0ec164e..7765bb51c 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -167,7 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) - conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) + conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) @@ -489,6 +489,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } + wgProxy.SetDisconnectListener(conn.onRelayDisconnected) + conn.dumpState.NewLocalProxy() conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index aa8f7d635..5e2900609 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -19,6 +19,7 @@ type RelayConnInfo struct { } type WorkerRelay struct { + peerCtx context.Context log *log.Entry isController bool config ConnConfig @@ -33,8 +34,9 @@ type WorkerRelay struct { wgWatcher *WGWatcher } -func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { +func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { r := &WorkerRelay{ + peerCtx: ctx, log: log, isController: ctrl, config: config, @@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) - relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) + relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { w.log.Debugf("handled offer by reusing existing relay connection") diff --git a/relay/auth/validator.go b/relay/auth/validator.go index 854efd5bb..56a20fbfe 100644 --- a/relay/auth/validator.go +++ b/relay/auth/validator.go @@ -7,13 +7,6 @@ import ( authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" ) -// Validator is an interface that defines the Validate method. -type Validator interface { - Validate(any) error - // Deprecated: Use Validate instead. - ValidateHelloMsgType(any) error -} - type TimedHMACValidator struct { authenticatorV2 *authv2.Validator authenticator *auth.TimedHMACValidator diff --git a/relay/client/client.go b/relay/client/client.go index 9e7e54393..2bf679ecb 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -124,15 +124,14 @@ func (cc *connContainer) close() { // While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. type Client struct { log *log.Entry - parentCtx context.Context connectionURL string authTokenStore *auth.TokenStore - hashedID []byte + hashedID messages.PeerID bufPool *sync.Pool relayConn net.Conn - conns map[string]*connContainer + conns map[messages.PeerID]*connContainer serviceIsRunning bool mu sync.Mutex // protect serviceIsRunning and conns readLoopMutex sync.Mutex @@ -142,14 +141,17 @@ type Client struct { onDisconnectListener func(string) listenerMutex sync.Mutex + + stateSubscription *PeersStateSubscription } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect -func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { - hashedID, hashedStringId := messages.HashID(peerID) +func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { + hashedID := messages.HashID(peerID) + relayLog := log.WithFields(log.Fields{"relay": serverURL}) + c := &Client{ - log: log.WithFields(log.Fields{"relay": serverURL}), - parentCtx: ctx, + log: relayLog, connectionURL: serverURL, authTokenStore: authTokenStore, hashedID: hashedID, @@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token return &buf }, }, - conns: make(map[string]*connContainer), + conns: make(map[messages.PeerID]*connContainer), } - c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId) + + c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID) return c } // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. -func (c *Client) Connect() error { +func (c *Client) Connect(ctx context.Context) error { c.log.Infof("connecting to relay server") c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -178,17 +181,23 @@ func (c *Client) Connect() error { return nil } - if err := c.connect(); err != nil { + if err := c.connect(ctx); err != nil { return err } + c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID) + c.log = c.log.WithField("relay", c.instanceURL.String()) c.log.Infof("relay connection established") c.serviceIsRunning = true + internallyStoppedFlag := newInternalStopFlag() + hc := healthcheck.NewReceiver(c.log) + go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag) + c.wgReadLoop.Add(1) - go c.readLoop(c.relayConn) + go c.readLoop(hc, c.relayConn, internallyStoppedFlag) return nil } @@ -196,26 +205,41 @@ func (c *Client) Connect() error { // OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress // to the relay server, the function will block until the connection is established or timed out. Otherwise, // it will return immediately. +// It block until the server confirm the peer is online. // todo: what should happen if call with the same peerID with multiple times? -func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) { + peerID := messages.HashID(dstPeerID) + c.mu.Lock() if !c.serviceIsRunning { + c.mu.Unlock() return nil, fmt.Errorf("relay connection is not established") } - - hashedID, hashedStringID := messages.HashID(dstPeerID) - _, ok := c.conns[hashedStringID] + _, ok := c.conns[peerID] if ok { + c.mu.Unlock() return nil, ErrConnAlreadyExists } + c.mu.Unlock() - c.log.Infof("open connection to peer: %s", hashedStringID) + if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil { + c.log.Errorf("peer not available: %s, %s", peerID, err) + return nil, err + } + + c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID) msgChannel := make(chan Msg, 100) - conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) + conn := NewConn(c, peerID, msgChannel, c.instanceURL) - c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel) + c.mu.Lock() + _, ok = c.conns[peerID] + if ok { + c.mu.Unlock() + _ = conn.Close() + return nil, ErrConnAlreadyExists + } + c.conns[peerID] = newConnContainer(c.log, conn, msgChannel) + c.mu.Unlock() return conn, nil } @@ -254,7 +278,7 @@ func (c *Client) Close() error { return c.close(true) } -func (c *Client) connect() error { +func (c *Client) connect(ctx context.Context) error { rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{}) conn, err := rd.Dial() if err != nil { @@ -262,7 +286,7 @@ func (c *Client) connect() error { } c.relayConn = conn - if err = c.handShake(); err != nil { + if err = c.handShake(ctx); err != nil { cErr := conn.Close() if cErr != nil { c.log.Errorf("failed to close connection: %s", cErr) @@ -273,7 +297,7 @@ func (c *Client) connect() error { return nil } -func (c *Client) handShake() error { +func (c *Client) handShake(ctx context.Context) error { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { c.log.Errorf("failed to marshal auth message: %s", err) @@ -286,7 +310,7 @@ func (c *Client) handShake() error { return err } buf := make([]byte, messages.MaxHandshakeRespSize) - n, err := c.readWithTimeout(buf) + n, err := c.readWithTimeout(ctx, buf) if err != nil { c.log.Errorf("failed to read auth response: %s", err) return err @@ -319,11 +343,7 @@ func (c *Client) handShake() error { return nil } -func (c *Client) readLoop(relayConn net.Conn) { - internallyStoppedFlag := newInternalStopFlag() - hc := healthcheck.NewReceiver(c.log) - go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag) - +func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) { var ( errExit error n int @@ -370,6 +390,7 @@ func (c *Client) readLoop(relayConn net.Conn) { c.instanceURL = nil c.muInstanceURL.Unlock() + c.stateSubscription.Cleanup() c.wgReadLoop.Done() _ = c.close(false) c.notifyDisconnected() @@ -382,6 +403,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, c.bufPool.Put(bufPtr) case messages.MsgTypeTransport: return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) + case messages.MsgTypePeersOnline: + c.handlePeersOnlineMsg(buf) + c.bufPool.Put(bufPtr) + return true + case messages.MsgTypePeersWentOffline: + c.handlePeersWentOfflineMsg(buf) + c.bufPool.Put(bufPtr) + return true case messages.MsgTypeClose: c.log.Debugf("relay connection close by server") c.bufPool.Put(bufPtr) @@ -413,18 +442,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe return true } - stringID := messages.HashIDToString(peerID) - c.mu.Lock() if !c.serviceIsRunning { c.mu.Unlock() c.bufPool.Put(bufPtr) return false } - container, ok := c.conns[stringID] + container, ok := c.conns[*peerID] c.mu.Unlock() if !ok { - c.log.Errorf("peer not found: %s", stringID) + c.log.Errorf("peer not found: %s", peerID.String()) c.bufPool.Put(bufPtr) return true } @@ -437,9 +464,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe return true } -func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { +func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) { c.mu.Lock() - conn, ok := c.conns[id] + conn, ok := c.conns[dstID] c.mu.Unlock() if !ok { return 0, net.ErrClosed @@ -464,7 +491,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [ return len(payload), err } -func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { +func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { for { select { case _, ok := <-hc.OnTimeout: @@ -478,7 +505,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in c.log.Warnf("failed to close connection: %s", err) } return - case <-c.parentCtx.Done(): + case <-ctx.Done(): err := c.close(true) if err != nil { c.log.Errorf("failed to teardown connection: %s", err) @@ -492,10 +519,31 @@ func (c *Client) closeAllConns() { for _, container := range c.conns { container.close() } - c.conns = make(map[string]*connContainer) + c.conns = make(map[messages.PeerID]*connContainer) } -func (c *Client) closeConn(connReference *Conn, id string) error { +func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, peerID := range peerIDs { + container, ok := c.conns[peerID] + if !ok { + c.log.Warnf("can not close connection, peer not found: %s", peerID) + continue + } + + container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID) + container.close() + delete(c.conns, peerID) + } + + if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil { + c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err) + } +} + +func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error { c.mu.Lock() defer c.mu.Unlock() @@ -507,6 +555,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error { if container.conn != connReference { return fmt.Errorf("conn reference mismatch") } + + if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil { + container.log.Errorf("failed to unsubscribe from peer state change: %s", err) + } + c.log.Infof("free up connection to peer: %s", id) delete(c.conns, id) container.close() @@ -559,8 +612,8 @@ func (c *Client) writeCloseMsg() { } } -func (c *Client) readWithTimeout(buf []byte) (int, error) { - ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) +func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) { + ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout) defer cancel() readDone := make(chan struct{}) @@ -581,3 +634,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) { return n, err } } + +func (c *Client) handlePeersOnlineMsg(buf []byte) { + peersID, err := messages.UnmarshalPeersOnlineMsg(buf) + if err != nil { + c.log.Errorf("failed to unmarshal peers online msg: %s", err) + return + } + c.stateSubscription.OnPeersOnline(peersID) +} + +func (c *Client) handlePeersWentOfflineMsg(buf []byte) { + peersID, err := messages.UnMarshalPeersWentOffline(buf) + if err != nil { + c.log.Errorf("failed to unmarshal peers went offline msg: %s", err) + return + } + c.stateSubscription.OnPeersWentOffline(peersID) +} diff --git a/relay/client/client_test.go b/relay/client/client_test.go index 7ddfba4c6..dd5f5fe1e 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -18,14 +18,19 @@ import ( ) var ( - av = &allow.Auth{} hmacTokenStore = &hmac.TokenStore{} serverListenAddr = "127.0.0.1:1234" serverURL = "rel://127.0.0.1:1234" + serverCfg = server.Config{ + Meter: otel.Meter(""), + ExposedAddress: serverURL, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } ) func TestMain(m *testing.M) { - _ = util.InitLog("error", "console") + _ = util.InitLog("debug", "console") code := m.Run() os.Exit(code) } @@ -33,7 +38,7 @@ func TestMain(m *testing.M) { func TestClient(t *testing.T) { ctx := context.Background() - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -58,37 +63,37 @@ func TestClient(t *testing.T) { t.Fatalf("failed to start server: %s", err) } t.Log("alice connecting to server") - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientAlice.Close() t.Log("placeholder connecting to server") - clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") - err = clientPlaceHolder.Connect() + clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder") + err = clientPlaceHolder.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientPlaceHolder.Close() t.Log("Bob connecting to server") - clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") - err = clientBob.Connect() + clientBob := NewClient(serverURL, hmacTokenStore, "bob") + err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } defer clientBob.Close() t.Log("Alice open connection to Bob") - connAliceToBob, err := clientAlice.OpenConn("bob") + connAliceToBob, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } t.Log("Bob open connection to Alice") - connBobToAlice, err := clientBob.OpenConn("alice") + connBobToAlice, err := clientBob.OpenConn(ctx, "alice") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -115,7 +120,7 @@ func TestClient(t *testing.T) { func TestRegistration(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -132,8 +137,8 @@ func TestRegistration(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { _ = srv.Shutdown(ctx) t.Fatalf("failed to connect to server: %s", err) @@ -172,8 +177,8 @@ func TestRegistrationTimeout(t *testing.T) { _ = fakeTCPListener.Close() }(fakeTCPListener) - clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err == nil { t.Errorf("failed to connect to server: %s", err) } @@ -189,7 +194,7 @@ func TestEcho(t *testing.T) { idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -213,8 +218,8 @@ func TestEcho(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -225,8 +230,8 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) - err = clientBob.Connect() + clientBob := NewClient(serverURL, hmacTokenStore, idBob) + err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -237,12 +242,12 @@ func TestEcho(t *testing.T) { } }() - connAliceToBob, err := clientAlice.OpenConn(idBob) + connAliceToBob, err := clientAlice.OpenConn(ctx, idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(idAlice) + connBobToAlice, err := clientBob.OpenConn(ctx, idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -278,7 +283,7 @@ func TestBindToUnavailabePeer(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -303,14 +308,14 @@ func TestBindToUnavailabePeer(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - _, err = clientAlice.OpenConn("bob") - if err != nil { - t.Errorf("failed to bind channel: %s", err) + _, err = clientAlice.OpenConn(ctx, "bob") + if err == nil { + t.Errorf("expected error when binding to unavailable peer, got nil") } log.Infof("closing client") @@ -324,7 +329,7 @@ func TestBindReconnect(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -349,24 +354,24 @@ func TestBindReconnect(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + clientBob := NewClient(serverURL, hmacTokenStore, "bob") + err = clientBob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - _, err = clientAlice.OpenConn("bob") + _, err = clientAlice.OpenConn(ctx, "bob") if err != nil { - t.Errorf("failed to bind channel: %s", err) + t.Fatalf("failed to bind channel: %s", err) } - clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") - err = clientBob.Connect() - if err != nil { - t.Errorf("failed to connect to server: %s", err) - } - - chBob, err := clientBob.OpenConn("alice") + chBob, err := clientBob.OpenConn(ctx, "alice") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -377,18 +382,28 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + clientAlice = NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - chAlice, err := clientAlice.OpenConn("bob") + chAlice, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } testString := "hello alice, I am bob" + _, err = chBob.Write([]byte(testString)) + if err == nil { + t.Errorf("expected error when writing to channel, got nil") + } + + chBob, err = clientBob.OpenConn(ctx, "alice") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + _, err = chBob.Write([]byte(testString)) if err != nil { t.Errorf("failed to write to channel: %s", err) @@ -415,7 +430,7 @@ func TestCloseConn(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -440,13 +455,19 @@ func TestCloseConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + bob := NewClient(serverURL, hmacTokenStore, "bob") + err = bob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - conn, err := clientAlice.OpenConn("bob") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -472,7 +493,7 @@ func TestCloseRelayConn(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -496,13 +517,19 @@ func TestCloseRelayConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") - err = clientAlice.Connect() + bob := NewClient(serverURL, hmacTokenStore, "bob") + err = bob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - conn, err := clientAlice.OpenConn("bob") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect(ctx) + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn(ctx, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -514,7 +541,7 @@ func TestCloseRelayConn(t *testing.T) { t.Errorf("unexpected reading from closed connection") } - _, err = clientAlice.OpenConn("bob") + _, err = clientAlice.OpenConn(ctx, "bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } @@ -524,7 +551,7 @@ func TestCloseByServer(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv1, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -544,8 +571,8 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = relayClient.Connect() + relayClient := NewClient(serverURL, hmacTokenStore, idAlice) + err = relayClient.Connect(ctx) if err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -567,7 +594,7 @@ func TestCloseByServer(t *testing.T) { log.Fatalf("timeout waiting for client to disconnect") } - _, err = relayClient.OpenConn("bob") + _, err = relayClient.OpenConn(ctx, "bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } @@ -577,7 +604,7 @@ func TestCloseByClient(t *testing.T) { ctx := context.Background() srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -596,8 +623,8 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = relayClient.Connect() + relayClient := NewClient(serverURL, hmacTokenStore, idAlice) + err = relayClient.Connect(ctx) if err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -607,7 +634,7 @@ func TestCloseByClient(t *testing.T) { t.Errorf("failed to close client: %s", err) } - _, err = relayClient.OpenConn("bob") + _, err = relayClient.OpenConn(ctx, "bob") if err == nil { t.Errorf("unexpected opening connection to closed server") } @@ -623,7 +650,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { idAlice := "alice" idBob := "bob" srvCfg := server.ListenerConfig{Address: serverListenAddr} - srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -647,8 +674,8 @@ func TestCloseNotDrainedChannel(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) - err = clientAlice.Connect() + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -659,8 +686,8 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) - err = clientBob.Connect() + clientBob := NewClient(serverURL, hmacTokenStore, idBob) + err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -671,12 +698,12 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - connAliceToBob, err := clientAlice.OpenConn(idBob) + connAliceToBob, err := clientAlice.OpenConn(ctx, idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(idAlice) + connBobToAlice, err := clientBob.OpenConn(ctx, idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/relay/client/conn.go b/relay/client/conn.go index fe1b6fb52..d8cffa695 100644 --- a/relay/client/conn.go +++ b/relay/client/conn.go @@ -3,13 +3,14 @@ package client import ( "net" "time" + + "github.com/netbirdio/netbird/relay/messages" ) // Conn represent a connection to a relayed remote peer. type Conn struct { client *Client - dstID []byte - dstStringID string + dstID messages.PeerID messageChan chan Msg instanceURL *RelayAddr } @@ -17,14 +18,12 @@ type Conn struct { // NewConn creates a new connection to a relayed remote peer. // client: the client instance, it used to send messages to the destination peer // dstID: the destination peer ID -// dstStringID: the destination peer ID in string format // messageChan: the channel where the messages will be received // instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer -func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { +func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn { c := &Conn{ client: client, dstID: dstID, - dstStringID: dstStringID, messageChan: messageChan, instanceURL: instanceURL, } @@ -33,7 +32,7 @@ func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan } func (c *Conn) Write(p []byte) (n int, err error) { - return c.client.writeTo(c, c.dstStringID, c.dstID, p) + return c.client.writeTo(c, c.dstID, p) } func (c *Conn) Read(b []byte) (n int, err error) { @@ -48,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Close() error { - return c.client.closeConn(c, c.dstStringID) + return c.client.closeConn(c, c.dstID) } func (c *Conn) LocalAddr() net.Addr { diff --git a/relay/client/guard.go b/relay/client/guard.go index 554330ea3..100892d81 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -80,7 +80,7 @@ func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - if err := rc.Connect(); err != nil { + if err := rc.Connect(parentCtx); err != nil { log.Errorf("failed to reconnect to relay server: %s", err) return false } diff --git a/relay/client/manager.go b/relay/client/manager.go index 26b113050..0fb682d95 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -42,7 +42,7 @@ type OnServerCloseListener func() // ManagerService is the interface for the relay manager. type ManagerService interface { Serve() error - OpenConn(serverAddress, peerKey string) (net.Conn, error) + OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error RelayInstanceAddress() (string, error) ServerURLs() []string @@ -123,7 +123,7 @@ func (m *Manager) Serve() error { // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. -func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { m.relayClientMu.Lock() defer m.relayClientMu.Unlock() @@ -141,10 +141,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { ) if !foreign { log.Debugf("open peer connection via permanent server: %s", peerKey) - netConn, err = m.relayClient.OpenConn(peerKey) + netConn, err = m.relayClient.OpenConn(ctx, peerKey) } else { log.Debugf("open peer connection via foreign server: %s", serverAddress) - netConn, err = m.openConnVia(serverAddress, peerKey) + netConn, err = m.openConnVia(ctx, serverAddress, peerKey) } if err != nil { return nil, err @@ -229,7 +229,7 @@ func (m *Manager) UpdateToken(token *relayAuth.Token) error { return m.tokenStore.UpdateToken(token) } -func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { +func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) { // check if already has a connection to the desired relay server m.relayClientsMutex.RLock() rt, ok := m.relayClients[serverAddress] @@ -240,7 +240,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { if rt.err != nil { return nil, rt.err } - return rt.relayClient.OpenConn(peerKey) + return rt.relayClient.OpenConn(ctx, peerKey) } m.relayClientsMutex.RUnlock() @@ -255,7 +255,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { if rt.err != nil { return nil, rt.err } - return rt.relayClient.OpenConn(peerKey) + return rt.relayClient.OpenConn(ctx, peerKey) } // create a new relay client and store it in the relayClients map @@ -264,8 +264,8 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) - err := relayClient.Connect() + relayClient := NewClient(serverAddress, m.tokenStore, m.peerID) + err := relayClient.Connect(m.ctx) if err != nil { rt.err = err rt.Unlock() @@ -279,7 +279,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { rt.relayClient = relayClient rt.Unlock() - conn, err := relayClient.OpenConn(peerKey) + conn, err := relayClient.OpenConn(ctx, peerKey) if err != nil { return nil, err } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index bfc342f25..d20cdaac0 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" + "github.com/netbirdio/netbird/relay/auth/allow" "github.com/netbirdio/netbird/relay/server" ) @@ -22,16 +23,22 @@ func TestEmptyURL(t *testing.T) { func TestForeignConn(t *testing.T) { ctx := context.Background() - srvCfg1 := server.ListenerConfig{ + lstCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + + srv1, err := server.NewServer(server.Config{ + Meter: otel.Meter(""), + ExposedAddress: lstCfg1.Address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + }) if err != nil { t.Fatalf("failed to create server: %s", err) } errChan := make(chan error, 1) go func() { - err := srv1.Listen(srvCfg1) + err := srv1.Listen(lstCfg1) if err != nil { errChan <- err } @@ -51,7 +58,12 @@ func TestForeignConn(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + srv2, err := server.NewServer(server.Config{ + Meter: otel.Meter(""), + ExposedAddress: srvCfg2.Address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + }) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -74,32 +86,26 @@ func TestForeignConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - idAlice := "alice" - log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) - err = clientAlice.Serve() - if err != nil { + clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice") + if err := clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - idBob := "bob" - log.Debugf("connect by bob") - clientBob := NewManager(mCtx, toURL(srvCfg2), idBob) - err = clientBob.Serve() - if err != nil { + clientBob := NewManager(mCtx, toURL(srvCfg2), "bob") + if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } bobsSrvAddr, err := clientBob.RelayInstanceAddress() if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) + connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) + connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -137,7 +143,7 @@ func TestForeginConnClose(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + srv1, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -163,7 +169,7 @@ func TestForeginConnClose(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + srv2, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -186,16 +192,20 @@ func TestForeginConnClose(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - idAlice := "alice" - log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, toURL(srvCfg1), idAlice) + + mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob") + if err := mgrBob.Serve(); err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + mgr := NewManager(mCtx, toURL(srvCfg1), "alice") err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -212,7 +222,7 @@ func TestForeginAutoClose(t *testing.T) { srvCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + srv1, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -241,7 +251,7 @@ func TestForeginAutoClose(t *testing.T) { srvCfg2 := server.ListenerConfig{ Address: "localhost:2234", } - srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + srv2, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -277,7 +287,7 @@ func TestForeginAutoClose(t *testing.T) { } t.Log("open connection to another peer") - conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") + conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer") if err != nil { t.Fatalf("failed to bind channel: %s", err) } @@ -305,7 +315,7 @@ func TestAutoReconnect(t *testing.T) { srvCfg := server.ListenerConfig{ Address: "localhost:1234", } - srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -330,6 +340,13 @@ func TestAutoReconnect(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() + + clientBob := NewManager(mCtx, toURL(srvCfg), "bob") + err = clientBob.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") err = clientAlice.Serve() if err != nil { @@ -339,7 +356,7 @@ func TestAutoReconnect(t *testing.T) { if err != nil { t.Errorf("failed to get relay address: %s", err) } - conn, err := clientAlice.OpenConn(ra, "bob") + conn, err := clientAlice.OpenConn(ctx, ra, "bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } @@ -357,7 +374,7 @@ func TestAutoReconnect(t *testing.T) { time.Sleep(reconnectingTimeout + 1*time.Second) log.Infof("reopent the connection") - _, err = clientAlice.OpenConn(ra, "bob") + _, err = clientAlice.OpenConn(ctx, ra, "bob") if err != nil { t.Errorf("failed to open channel: %s", err) } @@ -366,24 +383,27 @@ func TestAutoReconnect(t *testing.T) { func TestNotifierDoubleAdd(t *testing.T) { ctx := context.Background() - srvCfg1 := server.ListenerConfig{ + listenerCfg1 := server.ListenerConfig{ Address: "localhost:1234", } - srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + srv, err := server.NewServer(server.Config{ + Meter: otel.Meter(""), + ExposedAddress: listenerCfg1.Address, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + }) if err != nil { t.Fatalf("failed to create server: %s", err) } errChan := make(chan error, 1) go func() { - err := srv1.Listen(srvCfg1) - if err != nil { + if err := srv.Listen(listenerCfg1); err != nil { errChan <- err } }() defer func() { - err := srv1.Shutdown(ctx) - if err != nil { + if err := srv.Shutdown(ctx); err != nil { t.Errorf("failed to close server: %s", err) } }() @@ -392,17 +412,21 @@ func TestNotifierDoubleAdd(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - idAlice := "alice" log.Debugf("connect by alice") mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) - err = clientAlice.Serve() - if err != nil { + + clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob") + if err = clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") + clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice") + if err = clientAlice.Serve(); err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob") if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/relay/client/peer_subscription.go b/relay/client/peer_subscription.go new file mode 100644 index 000000000..03e7127b3 --- /dev/null +++ b/relay/client/peer_subscription.go @@ -0,0 +1,168 @@ +package client + +import ( + "context" + "errors" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/messages" +) + +const ( + OpenConnectionTimeout = 30 * time.Second +) + +type relayedConnWriter interface { + Write(p []byte) (n int, err error) +} + +// PeersStateSubscription manages subscriptions to peer state changes (online/offline) +// over a relay connection. It allows tracking peers' availability and handling offline +// events via a callback. We get online notification from the server only once. +type PeersStateSubscription struct { + log *log.Entry + relayConn relayedConnWriter + offlineCallback func(peerIDs []messages.PeerID) + + listenForOfflinePeers map[messages.PeerID]struct{} + waitingPeers map[messages.PeerID]chan struct{} +} + +func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription { + return &PeersStateSubscription{ + log: log, + relayConn: relayConn, + offlineCallback: offlineCallback, + listenForOfflinePeers: make(map[messages.PeerID]struct{}), + waitingPeers: make(map[messages.PeerID]chan struct{}), + } +} + +// OnPeersOnline should be called when a notification is received that certain peers have come online. +// It checks if any of the peers are being waited on and signals their availability. +func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) { + for _, peerID := range peersID { + waitCh, ok := s.waitingPeers[peerID] + if !ok { + continue + } + + close(waitCh) + delete(s.waitingPeers, peerID) + } +} + +func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) { + relevantPeers := make([]messages.PeerID, 0, len(peersID)) + for _, peerID := range peersID { + if _, ok := s.listenForOfflinePeers[peerID]; ok { + relevantPeers = append(relevantPeers, peerID) + } + } + + if len(relevantPeers) > 0 { + s.offlineCallback(relevantPeers) + } +} + +// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes. +// todo: when we unsubscribe while this is running, this will not return with error +func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error { + // Check if already waiting for this peer + if _, exists := s.waitingPeers[peerID]; exists { + return errors.New("already waiting for peer to come online") + } + + // Create a channel to wait for the peer to come online + waitCh := make(chan struct{}) + s.waitingPeers[peerID] = waitCh + + if err := s.subscribeStateChange([]messages.PeerID{peerID}); err != nil { + s.log.Errorf("failed to subscribe to peer state: %s", err) + close(waitCh) + delete(s.waitingPeers, peerID) + return err + } + + defer func() { + if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh { + close(waitCh) + delete(s.waitingPeers, peerID) + } + }() + + // Wait for peer to come online or context to be cancelled + timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout) + defer cancel() + select { + case <-waitCh: + s.log.Debugf("peer %s is now online", peerID) + return nil + case <-timeoutCtx.Done(): + s.log.Debugf("context timed out while waiting for peer %s to come online", peerID) + if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil { + s.log.Errorf("failed to unsubscribe from peer state: %s", err) + } + return timeoutCtx.Err() + } +} + +func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error { + msgErr := s.unsubscribeStateChange(peerIDs) + + for _, peerID := range peerIDs { + if wch, ok := s.waitingPeers[peerID]; ok { + close(wch) + delete(s.waitingPeers, peerID) + } + + delete(s.listenForOfflinePeers, peerID) + } + + return msgErr +} + +func (s *PeersStateSubscription) Cleanup() { + for _, waitCh := range s.waitingPeers { + close(waitCh) + } + + s.waitingPeers = make(map[messages.PeerID]chan struct{}) + s.listenForOfflinePeers = make(map[messages.PeerID]struct{}) +} + +func (s *PeersStateSubscription) subscribeStateChange(peerIDs []messages.PeerID) error { + msgs, err := messages.MarshalSubPeerStateMsg(peerIDs) + if err != nil { + return err + } + + for _, peer := range peerIDs { + s.listenForOfflinePeers[peer] = struct{}{} + } + + for _, msg := range msgs { + if _, err := s.relayConn.Write(msg); err != nil { + return err + } + + } + return nil +} + +func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error { + msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs) + if err != nil { + return err + } + + var connWriteErr error + for _, msg := range msgs { + if _, err := s.relayConn.Write(msg); err != nil { + connWriteErr = err + } + } + return connWriteErr +} diff --git a/relay/client/peer_subscription_test.go b/relay/client/peer_subscription_test.go new file mode 100644 index 000000000..0437efa04 --- /dev/null +++ b/relay/client/peer_subscription_test.go @@ -0,0 +1,99 @@ +package client + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/netbirdio/netbird/relay/messages" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockRelayedConn struct { +} + +func (m *mockRelayedConn) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) { + peerID := messages.HashID("peer1") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) // discard log output + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Launch wait in background + go func() { + time.Sleep(100 * time.Millisecond) + sub.OnPeersOnline([]messages.PeerID{peerID}) + }() + + err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + assert.NoError(t, err) +} + +func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) { + peerID := messages.HashID("peer2") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) +} + +func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) { + peerID := messages.HashID("peer3") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + ctx := context.Background() + go func() { + _ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + + }() + time.Sleep(100 * time.Millisecond) + err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID) + require.Error(t, err) + assert.Contains(t, err.Error(), "already waiting") +} + +func TestUnsubscribeStateChange(t *testing.T) { + peerID := messages.HashID("peer4") + mockConn := &mockRelayedConn{} + logger := logrus.New() + logger.SetOutput(&bytes.Buffer{}) + sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil) + + doneChan := make(chan struct{}) + go func() { + _ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID) + close(doneChan) + }() + time.Sleep(100 * time.Millisecond) + + err := sub.UnsubscribeStateChange([]messages.PeerID{peerID}) + assert.NoError(t, err) + + select { + case <-doneChan: + case <-time.After(200 * time.Millisecond): + // Expected timeout, meaning the subscription was successfully unsubscribed + t.Errorf("timeout") + } +} diff --git a/relay/client/picker.go b/relay/client/picker.go index eb5062dbb..9565425a8 100644 --- a/relay/client/picker.go +++ b/relay/client/picker.go @@ -70,8 +70,8 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { log.Infof("try to connecting to relay server: %s", url) - relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) - err := relayClient.Connect() + relayClient := NewClient(url, sp.TokenStore, sp.PeerID) + err := relayClient.Connect(ctx) resultChan <- connResult{ RelayClient: relayClient, Url: url, diff --git a/relay/cmd/root.go b/relay/cmd/root.go index d603ff73b..15090024c 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -141,7 +141,14 @@ func execute(cmd *cobra.Command, args []string) error { hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) - srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) + cfg := server.Config{ + Meter: metricsServer.Meter, + ExposedAddress: cobraConfig.ExposedAddress, + AuthValidator: authenticator, + TLSSupport: tlsSupport, + } + + srv, err := server.NewServer(cfg) if err != nil { log.Debugf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err) diff --git a/relay/messages/id.go b/relay/messages/id.go index e2162cd3b..96ace3478 100644 --- a/relay/messages/id.go +++ b/relay/messages/id.go @@ -8,24 +8,24 @@ import ( const ( prefixLength = 4 - IDSize = prefixLength + sha256.Size + peerIDSize = prefixLength + sha256.Size ) var ( prefix = []byte("sha-") // 4 bytes ) -// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string -func HashID(peerID string) ([]byte, string) { - idHash := sha256.Sum256([]byte(peerID)) - idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) - var prefixedHash []byte - prefixedHash = append(prefixedHash, prefix...) - prefixedHash = append(prefixedHash, idHash[:]...) - return prefixedHash, idHashString +type PeerID [peerIDSize]byte + +func (p PeerID) String() string { + return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:])) } -// HashIDToString converts a hash to a human-readable string -func HashIDToString(idHash []byte) string { - return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) +// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string +func HashID(peerID string) PeerID { + idHash := sha256.Sum256([]byte(peerID)) + var prefixedHash [peerIDSize]byte + copy(prefixedHash[:prefixLength], prefix) + copy(prefixedHash[prefixLength:], idHash[:]) + return prefixedHash } diff --git a/relay/messages/id_test.go b/relay/messages/id_test.go deleted file mode 100644 index 271a8f90d..000000000 --- a/relay/messages/id_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package messages - -import ( - "testing" -) - -func TestHashID(t *testing.T) { - hashedID, hashedStringId := HashID("alice") - enc := HashIDToString(hashedID) - if enc != hashedStringId { - t.Errorf("expected %s, got %s", hashedStringId, enc) - } -} diff --git a/relay/messages/message.go b/relay/messages/message.go index 7794c57bc..54671f5df 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -9,19 +9,26 @@ import ( const ( MaxHandshakeSize = 212 MaxHandshakeRespSize = 8192 + MaxMessageSize = 8820 CurrentProtocolVersion = 1 MsgTypeUnknown MsgType = 0 // Deprecated: Use MsgTypeAuth instead. - MsgTypeHello MsgType = 1 + MsgTypeHello = 1 // Deprecated: Use MsgTypeAuthResponse instead. - MsgTypeHelloResponse MsgType = 2 - MsgTypeTransport MsgType = 3 - MsgTypeClose MsgType = 4 - MsgTypeHealthCheck MsgType = 5 - MsgTypeAuth = 6 - MsgTypeAuthResponse = 7 + MsgTypeHelloResponse = 2 + MsgTypeTransport = 3 + MsgTypeClose = 4 + MsgTypeHealthCheck = 5 + MsgTypeAuth = 6 + MsgTypeAuthResponse = 7 + + // Peers state messages + MsgTypeSubscribePeerState = 8 + MsgTypeUnsubscribePeerState = 9 + MsgTypePeersOnline = 10 + MsgTypePeersWentOffline = 11 // base size of the message sizeOfVersionByte = 1 @@ -30,17 +37,17 @@ const ( // auth message sizeOfMagicByte = 4 - headerSizeAuth = sizeOfMagicByte + IDSize + headerSizeAuth = sizeOfMagicByte + peerIDSize offsetMagicByte = sizeOfProtoHeader offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth // hello message - headerSizeHello = sizeOfMagicByte + IDSize + headerSizeHello = sizeOfMagicByte + peerIDSize headerSizeHelloResp = 0 // transport - headerSizeTransport = IDSize + headerSizeTransport = peerIDSize offsetTransportID = sizeOfProtoHeader headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport ) @@ -72,6 +79,14 @@ func (m MsgType) String() string { return "close" case MsgTypeHealthCheck: return "health check" + case MsgTypeSubscribePeerState: + return "subscribe peer state" + case MsgTypeUnsubscribePeerState: + return "unsubscribe peer state" + case MsgTypePeersOnline: + return "peers online" + case MsgTypePeersWentOffline: + return "peers went offline" default: return "unknown" } @@ -102,7 +117,9 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) { MsgTypeAuth, MsgTypeTransport, MsgTypeClose, - MsgTypeHealthCheck: + MsgTypeHealthCheck, + MsgTypeSubscribePeerState, + MsgTypeUnsubscribePeerState: return msgType, nil default: return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) @@ -122,7 +139,9 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { MsgTypeAuthResponse, MsgTypeTransport, MsgTypeClose, - MsgTypeHealthCheck: + MsgTypeHealthCheck, + MsgTypePeersOnline, + MsgTypePeersWentOffline: return msgType, nil default: return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) @@ -135,11 +154,7 @@ func DetermineServerMessageType(msg []byte) (MsgType, error) { // message is used to authenticate the client with the server. The authentication is done using an HMAC method. // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // close the network connection without any response. -func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { - if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) - } - +func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) { msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions)) msg[0] = byte(CurrentProtocolVersion) @@ -147,7 +162,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) - msg = append(msg, peerID...) + msg = append(msg, peerID[:]...) msg = append(msg, additions...) return msg, nil @@ -156,7 +171,7 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { // Deprecated: Use UnmarshalAuthMsg instead. // UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to // authenticate the client with the server. -func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { +func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) { if len(msg) < sizeOfProtoHeader+headerSizeHello { return nil, nil, ErrInvalidMessageLength } @@ -164,7 +179,9 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { return nil, nil, errors.New("invalid magic header") } - return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil + peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello]) + + return &peerID, msg[headerSizeHello:], nil } // Deprecated: Use MarshalAuthResponse instead. @@ -197,34 +214,33 @@ func UnmarshalHelloResponse(msg []byte) ([]byte, error) { // message is used to authenticate the client with the server. The authentication is done using an HMAC method. // The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will // close the network connection without any response. -func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { - if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) +func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) { + if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize { + return nil, fmt.Errorf("too large auth payload") } - msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload)) - + msg := make([]byte, headerTotalSizeAuth+len(authPayload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeAuth) - copy(msg[sizeOfProtoHeader:], magicHeader) - - msg = append(msg, peerID...) - msg = append(msg, authPayload...) - + copy(msg[offsetAuthPeerID:], peerID[:]) + copy(msg[headerTotalSizeAuth:], authPayload) return msg, nil } // UnmarshalAuthMsg extracts peerID and the auth payload from the message -func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { +func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) { if len(msg) < headerTotalSizeAuth { return nil, nil, ErrInvalidMessageLength } + + // Validate the magic header if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) { return nil, nil, errors.New("invalid magic header") } - return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil + peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth]) + return &peerID, msg[headerTotalSizeAuth:], nil } // MarshalAuthResponse creates a response message to the auth. @@ -268,45 +284,48 @@ func MarshalCloseMsg() []byte { // MarshalTransportMsg creates a transport message. // The transport message is used to exchange data between peers. The message contains the data to be exchanged and the // destination peer hashed ID. -func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) { - if len(peerID) != IDSize { - return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) - } - - msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload)) +func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) { + // todo validate size + msg := make([]byte, headerTotalSizeTransport+len(payload)) msg[0] = byte(CurrentProtocolVersion) msg[1] = byte(MsgTypeTransport) - copy(msg[sizeOfProtoHeader:], peerID) - msg = append(msg, payload...) - + copy(msg[sizeOfProtoHeader:], peerID[:]) + copy(msg[sizeOfProtoHeader+peerIDSize:], payload) return msg, nil } // UnmarshalTransportMsg extracts the peerID and the payload from the transport message. -func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { +func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) { if len(buf) < headerTotalSizeTransport { return nil, nil, ErrInvalidMessageLength } - return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil + const offsetEnd = offsetTransportID + peerIDSize + var peerID PeerID + copy(peerID[:], buf[offsetTransportID:offsetEnd]) + return &peerID, buf[headerTotalSizeTransport:], nil } // UnmarshalTransportID extracts the peerID from the transport message. -func UnmarshalTransportID(buf []byte) ([]byte, error) { +func UnmarshalTransportID(buf []byte) (*PeerID, error) { if len(buf) < headerTotalSizeTransport { return nil, ErrInvalidMessageLength } - return buf[offsetTransportID:headerTotalSizeTransport], nil + + const offsetEnd = offsetTransportID + peerIDSize + var id PeerID + copy(id[:], buf[offsetTransportID:offsetEnd]) + return &id, nil } // UpdateTransportMsg updates the peerID in the transport message. // With this function the server can reuse the given byte slice to update the peerID in the transport message. So do // need to allocate a new byte slice. -func UpdateTransportMsg(msg []byte, peerID []byte) error { - if len(msg) < offsetTransportID+len(peerID) { +func UpdateTransportMsg(msg []byte, peerID PeerID) error { + if len(msg) < offsetTransportID+peerIDSize { return ErrInvalidMessageLength } - copy(msg[offsetTransportID:], peerID) + copy(msg[offsetTransportID:], peerID[:]) return nil } diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go index 19bede07b..59a89cad1 100644 --- a/relay/messages/message_test.go +++ b/relay/messages/message_test.go @@ -5,7 +5,7 @@ import ( ) func TestMarshalHelloMsg(t *testing.T) { - peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") msg, err := MarshalHelloMsg(peerID, nil) if err != nil { t.Fatalf("error: %v", err) @@ -24,13 +24,13 @@ func TestMarshalHelloMsg(t *testing.T) { if err != nil { t.Fatalf("error: %v", err) } - if string(receivedPeerID) != string(peerID) { + if receivedPeerID.String() != peerID.String() { t.Errorf("expected %s, got %s", peerID, receivedPeerID) } } func TestMarshalAuthMsg(t *testing.T) { - peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") msg, err := MarshalAuthMsg(peerID, []byte{}) if err != nil { t.Fatalf("error: %v", err) @@ -49,7 +49,7 @@ func TestMarshalAuthMsg(t *testing.T) { if err != nil { t.Fatalf("error: %v", err) } - if string(receivedPeerID) != string(peerID) { + if receivedPeerID.String() != peerID.String() { t.Errorf("expected %s, got %s", peerID, receivedPeerID) } } @@ -80,7 +80,7 @@ func TestMarshalAuthResponse(t *testing.T) { } func TestMarshalTransportMsg(t *testing.T) { - peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") payload := []byte("payload") msg, err := MarshalTransportMsg(peerID, payload) if err != nil { @@ -101,7 +101,7 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("failed to unmarshal transport id: %v", err) } - if string(uPeerID) != string(peerID) { + if uPeerID.String() != peerID.String() { t.Errorf("expected %s, got %s", peerID, uPeerID) } @@ -110,8 +110,8 @@ func TestMarshalTransportMsg(t *testing.T) { t.Fatalf("error: %v", err) } - if string(id) != string(peerID) { - t.Errorf("expected %s, got %s", peerID, id) + if id.String() != peerID.String() { + t.Errorf("expected: '%s', got: '%s'", peerID, id) } if string(respPayload) != string(payload) { diff --git a/relay/messages/peer_state.go b/relay/messages/peer_state.go new file mode 100644 index 000000000..f10bc7bdf --- /dev/null +++ b/relay/messages/peer_state.go @@ -0,0 +1,92 @@ +package messages + +import ( + "fmt" +) + +func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState)) +} + +func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState)) +} + +func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +func MarshalPeersOnline(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypePeersOnline)) +} + +func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) { + return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline)) +} + +func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) { + return unmarshalPeerIDs(buf) +} + +// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type +func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) { + if len(ids) == 0 { + return nil, fmt.Errorf("no list of peer ids provided") + } + + const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize + var messages [][]byte + + for i := 0; i < len(ids); i += maxPeersPerMessage { + end := i + maxPeersPerMessage + if end > len(ids) { + end = len(ids) + } + chunk := ids[i:end] + + totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize + buf := make([]byte, totalSize) + buf[0] = byte(CurrentProtocolVersion) + buf[1] = msgType + + offset := sizeOfProtoHeader + for _, id := range chunk { + copy(buf[offset:], id[:]) + offset += peerIDSize + } + + messages = append(messages, buf) + } + + return messages, nil +} + +// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer +func unmarshalPeerIDs(buf []byte) ([]PeerID, error) { + if len(buf) < sizeOfProtoHeader { + return nil, fmt.Errorf("invalid message format") + } + + if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 { + return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader) + } + + numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize + + ids := make([]PeerID, numIDs) + offset := sizeOfProtoHeader + for i := 0; i < numIDs; i++ { + copy(ids[i][:], buf[offset:offset+peerIDSize]) + offset += peerIDSize + } + + return ids, nil +} diff --git a/relay/messages/peer_state_test.go b/relay/messages/peer_state_test.go new file mode 100644 index 000000000..9e366da55 --- /dev/null +++ b/relay/messages/peer_state_test.go @@ -0,0 +1,144 @@ +package messages + +import ( + "bytes" + "testing" +) + +const ( + testPeerCount = 10 +) + +// Helper function to generate test PeerIDs +func generateTestPeerIDs(n int) []PeerID { + ids := make([]PeerID, n) + for i := 0; i < n; i++ { + for j := 0; j < peerIDSize; j++ { + ids[i][j] = byte(i + j) + } + } + return ids +} + +// Helper function to compare slices of PeerID +func peerIDEqual(a, b []PeerID) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !bytes.Equal(a[i][:], b[i][:]) { + return false + } + } + return true +} + +func TestMarshalUnmarshalSubPeerState(t *testing.T) { + ids := generateTestPeerIDs(testPeerCount) + + msgs, err := MarshalSubPeerStateMsg(ids) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var allIDs []PeerID + for _, msg := range msgs { + decoded, err := UnmarshalSubPeerStateMsg(msg) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + allIDs = append(allIDs, decoded...) + } + + if !peerIDEqual(ids, allIDs) { + t.Errorf("expected %v, got %v", ids, allIDs) + } +} + +func TestMarshalSubPeerState_EmptyInput(t *testing.T) { + _, err := MarshalSubPeerStateMsg([]PeerID{}) + if err == nil { + t.Errorf("expected error for empty input") + } +} + +func TestUnmarshalSubPeerState_Invalid(t *testing.T) { + // Too short + _, err := UnmarshalSubPeerStateMsg([]byte{1}) + if err == nil { + t.Errorf("expected error for short input") + } + + // Misaligned length + buf := make([]byte, sizeOfProtoHeader+1) + _, err = UnmarshalSubPeerStateMsg(buf) + if err == nil { + t.Errorf("expected error for misaligned input") + } +} + +func TestMarshalUnmarshalPeersOnline(t *testing.T) { + ids := generateTestPeerIDs(testPeerCount) + + msgs, err := MarshalPeersOnline(ids) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var allIDs []PeerID + for _, msg := range msgs { + decoded, err := UnmarshalPeersOnlineMsg(msg) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + allIDs = append(allIDs, decoded...) + } + + if !peerIDEqual(ids, allIDs) { + t.Errorf("expected %v, got %v", ids, allIDs) + } +} + +func TestMarshalPeersOnline_EmptyInput(t *testing.T) { + _, err := MarshalPeersOnline([]PeerID{}) + if err == nil { + t.Errorf("expected error for empty input") + } +} + +func TestUnmarshalPeersOnline_Invalid(t *testing.T) { + _, err := UnmarshalPeersOnlineMsg([]byte{1}) + if err == nil { + t.Errorf("expected error for short input") + } +} + +func TestMarshalUnmarshalPeersWentOffline(t *testing.T) { + ids := generateTestPeerIDs(testPeerCount) + + msgs, err := MarshalPeersWentOffline(ids) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var allIDs []PeerID + for _, msg := range msgs { + // MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline + decoded, err := UnmarshalPeersOnlineMsg(msg) + if err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + allIDs = append(allIDs, decoded...) + } + + if !peerIDEqual(ids, allIDs) { + t.Errorf("expected %v, got %v", ids, allIDs) + } +} + +func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) { + _, err := MarshalPeersWentOffline([]PeerID{}) + if err == nil { + t.Errorf("expected error for empty input") + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go index babd6f955..eb72b3bae 100644 --- a/relay/server/handshake.go +++ b/relay/server/handshake.go @@ -6,7 +6,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck "github.com/netbirdio/netbird/relay/messages/address" @@ -14,6 +13,12 @@ import ( authmsg "github.com/netbirdio/netbird/relay/messages/auth" ) +type Validator interface { + Validate(any) error + // Deprecated: Use Validate instead. + ValidateHelloMsgType(any) error +} + // preparedMsg contains the marshalled success response messages type preparedMsg struct { responseHelloMsg []byte @@ -54,14 +59,14 @@ func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { type handshake struct { conn net.Conn - validator auth.Validator + validator Validator preparedMsg *preparedMsg handshakeMethodAuth bool - peerID string + peerID *messages.PeerID } -func (h *handshake) handshakeReceive() ([]byte, error) { +func (h *handshake) handshakeReceive() (*messages.PeerID, error) { buf := make([]byte, messages.MaxHandshakeSize) n, err := h.conn.Read(buf) if err != nil { @@ -80,17 +85,14 @@ func (h *handshake) handshakeReceive() ([]byte, error) { return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) } - var ( - bytePeerID []byte - peerID string - ) + var peerID *messages.PeerID switch msgType { //nolint:staticcheck case messages.MsgTypeHello: - bytePeerID, peerID, err = h.handleHelloMsg(buf) + peerID, err = h.handleHelloMsg(buf) case messages.MsgTypeAuth: h.handshakeMethodAuth = true - bytePeerID, peerID, err = h.handleAuthMsg(buf) + peerID, err = h.handleAuthMsg(buf) default: return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) } @@ -98,7 +100,7 @@ func (h *handshake) handshakeReceive() ([]byte, error) { return nil, err } h.peerID = peerID - return bytePeerID, nil + return peerID, nil } func (h *handshake) handshakeResponse() error { @@ -116,40 +118,37 @@ func (h *handshake) handshakeResponse() error { return nil } -func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { +func (h *handshake) handleHelloMsg(buf []byte) (*messages.PeerID, error) { //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + peerID, authData, err := messages.UnmarshalHelloMsg(buf) if err != nil { - return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + return nil, fmt.Errorf("unmarshal hello message: %w", err) } - peerID := messages.HashIDToString(rawPeerID) log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) authMsg, err := authmsg.UnmarshalMsg(authData) if err != nil { - return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + return nil, fmt.Errorf("unmarshal auth message: %w", err) } //nolint:staticcheck if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + return nil, fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) } - return rawPeerID, peerID, nil + return peerID, nil } -func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { +func (h *handshake) handleAuthMsg(buf []byte) (*messages.PeerID, error) { rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) if err != nil { - return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + return nil, fmt.Errorf("unmarshal hello message: %w", err) } - peerID := messages.HashIDToString(rawPeerID) - if err := h.validator.Validate(authPayload); err != nil { - return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + return nil, fmt.Errorf("validate %s (%s): %w", rawPeerID.String(), h.conn.RemoteAddr(), err) } - return rawPeerID, peerID, nil + return rawPeerID, nil } diff --git a/relay/server/peer.go b/relay/server/peer.go index aa9790f63..c6fa8508f 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -12,43 +12,50 @@ import ( "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/store" ) const ( - bufferSize = 8820 + bufferSize = messages.MaxMessageSize errCloseConn = "failed to close connection to peer: %s" ) // Peer represents a peer connection type Peer struct { - metrics *metrics.Metrics - log *log.Entry - idS string - idB []byte - conn net.Conn - connMu sync.RWMutex - store *Store + metrics *metrics.Metrics + log *log.Entry + id messages.PeerID + conn net.Conn + connMu sync.RWMutex + store *store.Store + notifier *store.PeerNotifier + + peersListener *store.Listener } // NewPeer creates a new Peer instance and prepare custom logging -func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { - stringID := messages.HashIDToString(id) - return &Peer{ - metrics: metrics, - log: log.WithField("peer_id", stringID), - idS: stringID, - idB: id, - conn: conn, - store: store, +func NewPeer(metrics *metrics.Metrics, id messages.PeerID, conn net.Conn, store *store.Store, notifier *store.PeerNotifier) *Peer { + p := &Peer{ + metrics: metrics, + log: log.WithField("peer_id", id.String()), + id: id, + conn: conn, + store: store, + notifier: notifier, } + + return p } // Work reads data from the connection // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // the message accordingly. func (p *Peer) Work() { + p.peersListener = p.notifier.NewListener(p.sendPeersOnline, p.sendPeersWentOffline) defer func() { + p.notifier.RemoveListener(p.peersListener) + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.log.Errorf(errCloseConn, err) } @@ -94,6 +101,10 @@ func (p *Peer) Work() { } } +func (p *Peer) ID() messages.PeerID { + return p.id +} + func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { switch msgType { case messages.MsgTypeHealthCheck: @@ -107,6 +118,10 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * if err := p.conn.Close(); err != nil { log.Errorf(errCloseConn, err) } + case messages.MsgTypeSubscribePeerState: + p.handleSubscribePeerState(msg) + case messages.MsgTypeUnsubscribePeerState: + p.handleUnsubscribePeerState(msg) default: p.log.Warnf("received unexpected message type: %s", msgType) } @@ -145,7 +160,7 @@ func (p *Peer) Close() { // String returns the peer ID func (p *Peer) String() string { - return p.idS + return p.id.String() } func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { @@ -197,14 +212,14 @@ func (p *Peer) handleTransportMsg(msg []byte) { return } - stringPeerID := messages.HashIDToString(peerID) - dp, ok := p.store.Peer(stringPeerID) + item, ok := p.store.Peer(*peerID) if !ok { - p.log.Debugf("peer not found: %s", stringPeerID) + p.log.Debugf("peer not found: %s", peerID) return } + dp := item.(*Peer) - err = messages.UpdateTransportMsg(msg, p.idB) + err = messages.UpdateTransportMsg(msg, p.id) if err != nil { p.log.Errorf("failed to update transport message: %s", err) return @@ -217,3 +232,57 @@ func (p *Peer) handleTransportMsg(msg []byte) { } p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) } + +func (p *Peer) handleSubscribePeerState(msg []byte) { + peerIDs, err := messages.UnmarshalSubPeerStateMsg(msg) + if err != nil { + p.log.Errorf("failed to unmarshal open connection message: %s", err) + return + } + + p.log.Debugf("received subscription message for %d peers", len(peerIDs)) + onlinePeers := p.peersListener.AddInterestedPeers(peerIDs) + if len(onlinePeers) == 0 { + return + } + p.log.Debugf("response with %d online peers", len(onlinePeers)) + p.sendPeersOnline(onlinePeers) +} + +func (p *Peer) handleUnsubscribePeerState(msg []byte) { + peerIDs, err := messages.UnmarshalUnsubPeerStateMsg(msg) + if err != nil { + p.log.Errorf("failed to unmarshal open connection message: %s", err) + return + } + + p.peersListener.RemoveInterestedPeer(peerIDs) +} + +func (p *Peer) sendPeersOnline(peers []messages.PeerID) { + msgs, err := messages.MarshalPeersOnline(peers) + if err != nil { + p.log.Errorf("failed to marshal peer location message: %s", err) + return + } + + for n, msg := range msgs { + if _, err := p.Write(msg); err != nil { + p.log.Errorf("failed to write %d. peers offline message: %s", n, err) + } + } +} + +func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) { + msgs, err := messages.MarshalPeersWentOffline(peers) + if err != nil { + p.log.Errorf("failed to marshal peer location message: %s", err) + return + } + + for n, msg := range msgs { + if _, err := p.Write(msg); err != nil { + p.log.Errorf("failed to write %d. peers offline message: %s", n, err) + } + } +} diff --git a/relay/server/relay.go b/relay/server/relay.go index a5e77bc61..93fb00edb 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -4,26 +4,55 @@ import ( "context" "fmt" "net" - "net/url" - "strings" "sync" "time" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" - "github.com/netbirdio/netbird/relay/auth" //nolint:staticcheck "github.com/netbirdio/netbird/relay/metrics" + "github.com/netbirdio/netbird/relay/server/store" ) +type Config struct { + Meter metric.Meter + ExposedAddress string + TLSSupport bool + AuthValidator Validator + + instanceURL string +} + +func (c *Config) validate() error { + if c.Meter == nil { + c.Meter = otel.Meter("") + } + if c.ExposedAddress == "" { + return fmt.Errorf("exposed address is required") + } + + instanceURL, err := getInstanceURL(c.ExposedAddress, c.TLSSupport) + if err != nil { + return fmt.Errorf("invalid url: %v", err) + } + c.instanceURL = instanceURL + + if c.AuthValidator == nil { + return fmt.Errorf("auth validator is required") + } + return nil +} + // Relay represents the relay server type Relay struct { metrics *metrics.Metrics metricsCancel context.CancelFunc - validator auth.Validator + validator Validator - store *Store + store *store.Store + notifier *store.PeerNotifier instanceURL string preparedMsg *preparedMsg @@ -31,40 +60,40 @@ type Relay struct { closeMu sync.RWMutex } -// NewRelay creates a new Relay instance +// NewRelay creates and returns a new Relay instance. // // Parameters: -// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage -// metrics for the relay server. -// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this -// address as the relay server's instance URL. -// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The -// instance URL depends on this value. -// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the -// peers. +// +// config: A Config struct that holds the configuration needed to initialize the relay server. +// - Meter: A metric.Meter used for emitting metrics. If not set, a default no-op meter will be used. +// - ExposedAddress: The external address clients use to reach this relay. Required. +// - TLSSupport: A boolean indicating if the relay uses TLS. Affects the generated instance URL. +// - AuthValidator: A Validator implementation used to authenticate peers. Required. // // Returns: -// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil. -// Otherwise, the error contains the details of what went wrong. -func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) { +// +// A pointer to a Relay instance and an error. If initialization is successful, the error will be nil; +// otherwise, it will contain the reason the relay could not be created (e.g., invalid configuration). +func NewRelay(config Config) (*Relay, error) { + if err := config.validate(); err != nil { + return nil, fmt.Errorf("invalid config: %v", err) + } + ctx, metricsCancel := context.WithCancel(context.Background()) - m, err := metrics.NewMetrics(ctx, meter) + m, err := metrics.NewMetrics(ctx, config.Meter) if err != nil { metricsCancel() return nil, fmt.Errorf("creating app metrics: %v", err) } + peerStore := store.NewStore() r := &Relay{ metrics: m, metricsCancel: metricsCancel, - validator: validator, - store: NewStore(), - } - - r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport) - if err != nil { - metricsCancel() - return nil, fmt.Errorf("get instance URL: %v", err) + validator: config.AuthValidator, + instanceURL: config.instanceURL, + store: peerStore, + notifier: store.NewPeerNotifier(peerStore), } r.preparedMsg, err = newPreparedMsg(r.instanceURL) @@ -76,32 +105,6 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return r, nil } -// getInstanceURL checks if user supplied a URL scheme otherwise adds to the -// provided address according to TLS definition and parses the address before returning it -func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { - addr := exposedAddress - split := strings.Split(exposedAddress, "://") - switch { - case len(split) == 1 && tlsSupported: - addr = "rels://" + exposedAddress - case len(split) == 1 && !tlsSupported: - addr = "rel://" + exposedAddress - case len(split) > 2: - return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) - } - - parsedURL, err := url.ParseRequestURI(addr) - if err != nil { - return "", fmt.Errorf("invalid exposed address: %v", err) - } - - if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { - return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) - } - - return parsedURL.String(), nil -} - // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { acceptTime := time.Now() @@ -125,14 +128,17 @@ func (r *Relay) Accept(conn net.Conn) { return } - peer := NewPeer(r.metrics, peerID, conn, r.store) + peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) storeTime := time.Now() r.store.AddPeer(peer) + r.notifier.PeerCameOnline(peer.ID()) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() + r.notifier.PeerWentOffline(peer.ID()) r.store.DeletePeer(peer) peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) @@ -154,12 +160,12 @@ func (r *Relay) Shutdown(ctx context.Context) { wg := sync.WaitGroup{} peers := r.store.Peers() - for _, peer := range peers { + for _, v := range peers { wg.Add(1) go func(p *Peer) { p.CloseGracefully(ctx) wg.Done() - }(peer) + }(v.(*Peer)) } wg.Wait() r.metricsCancel() diff --git a/relay/server/server.go b/relay/server/server.go index 10aabcace..f0b480ee4 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -6,15 +6,12 @@ import ( "sync" "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/metric" - nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/relay/auth" "github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/ws" quictls "github.com/netbirdio/netbird/relay/tls" + log "github.com/sirupsen/logrus" ) // ListenerConfig is the configuration for the listener. @@ -33,13 +30,22 @@ type Server struct { listeners []listener.Listener } -// NewServer creates a new relay server instance. -// meter: the OpenTelemetry meter -// exposedAddress: this address will be used as the instance URL. It should be a domain:port format. -// tlsSupport: if true, the server will support TLS -// authValidator: the auth validator to use for the server -func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) { - relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator) +// NewServer creates and returns a new relay server instance. +// +// Parameters: +// +// config: A Config struct containing the necessary configuration: +// - Meter: An OpenTelemetry metric.Meter used for recording metrics. If nil, a default no-op meter is used. +// - ExposedAddress: The public address (in domain:port format) used as the server's instance URL. Required. +// - TLSSupport: A boolean indicating whether TLS is enabled for the server. +// - AuthValidator: A Validator used to authenticate peers. Required. +// +// Returns: +// +// A pointer to a Server instance and an error. If the configuration is valid and initialization succeeds, +// the returned error will be nil. Otherwise, the error will describe the problem. +func NewServer(config Config) (*Server, error) { + relay, err := NewRelay(config) if err != nil { return nil, err } diff --git a/relay/server/store/listener.go b/relay/server/store/listener.go new file mode 100644 index 000000000..e5f455795 --- /dev/null +++ b/relay/server/store/listener.go @@ -0,0 +1,121 @@ +package store + +import ( + "context" + "sync" + + "github.com/netbirdio/netbird/relay/messages" +) + +type Listener struct { + store *Store + + onlineChan chan messages.PeerID + offlineChan chan messages.PeerID + interestedPeersForOffline map[messages.PeerID]struct{} + interestedPeersForOnline map[messages.PeerID]struct{} + mu sync.RWMutex + + listenerCtx context.Context +} + +func newListener(store *Store) *Listener { + l := &Listener{ + store: store, + + onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol + offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol + interestedPeersForOffline: make(map[messages.PeerID]struct{}), + interestedPeersForOnline: make(map[messages.PeerID]struct{}), + } + + return l +} + +func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID { + availablePeers := make([]messages.PeerID, 0) + l.mu.Lock() + defer l.mu.Unlock() + + for _, id := range peerIDs { + l.interestedPeersForOnline[id] = struct{}{} + l.interestedPeersForOffline[id] = struct{}{} + } + + // collect online peers to response back to the caller + for _, id := range peerIDs { + _, ok := l.store.Peer(id) + if !ok { + continue + } + + availablePeers = append(availablePeers, id) + } + return availablePeers +} + +func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) { + l.mu.Lock() + defer l.mu.Unlock() + + for _, id := range peerIDs { + delete(l.interestedPeersForOffline, id) + delete(l.interestedPeersForOnline, id) + + } +} + +func (l *Listener) listenForEvents(ctx context.Context, onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) { + l.listenerCtx = ctx + for { + select { + case <-ctx.Done(): + return + case pID := <-l.onlineChan: + peers := make([]messages.PeerID, 0) + peers = append(peers, pID) + + for len(l.onlineChan) > 0 { + pID = <-l.onlineChan + peers = append(peers, pID) + } + + onPeersComeOnline(peers) + case pID := <-l.offlineChan: + peers := make([]messages.PeerID, 0) + peers = append(peers, pID) + + for len(l.offlineChan) > 0 { + pID = <-l.offlineChan + peers = append(peers, pID) + } + + onPeersWentOffline(peers) + } + } +} + +func (l *Listener) peerWentOffline(peerID messages.PeerID) { + l.mu.RLock() + defer l.mu.RUnlock() + + if _, ok := l.interestedPeersForOffline[peerID]; ok { + select { + case l.offlineChan <- peerID: + case <-l.listenerCtx.Done(): + } + } +} + +func (l *Listener) peerComeOnline(peerID messages.PeerID) { + l.mu.Lock() + defer l.mu.Unlock() + + if _, ok := l.interestedPeersForOnline[peerID]; ok { + select { + case l.onlineChan <- peerID: + case <-l.listenerCtx.Done(): + } + delete(l.interestedPeersForOnline, peerID) + } +} diff --git a/relay/server/store/notifier.go b/relay/server/store/notifier.go new file mode 100644 index 000000000..d04db478b --- /dev/null +++ b/relay/server/store/notifier.go @@ -0,0 +1,64 @@ +package store + +import ( + "context" + "sync" + + "github.com/netbirdio/netbird/relay/messages" +) + +type PeerNotifier struct { + store *Store + + listeners map[*Listener]context.CancelFunc + listenersMutex sync.RWMutex +} + +func NewPeerNotifier(store *Store) *PeerNotifier { + pn := &PeerNotifier{ + store: store, + listeners: make(map[*Listener]context.CancelFunc), + } + return pn +} + +func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener { + ctx, cancel := context.WithCancel(context.Background()) + listener := newListener(pn.store) + go listener.listenForEvents(ctx, onPeersComeOnline, onPeersWentOffline) + + pn.listenersMutex.Lock() + pn.listeners[listener] = cancel + pn.listenersMutex.Unlock() + return listener +} + +func (pn *PeerNotifier) RemoveListener(listener *Listener) { + pn.listenersMutex.Lock() + defer pn.listenersMutex.Unlock() + + cancel, ok := pn.listeners[listener] + if !ok { + return + } + cancel() + delete(pn.listeners, listener) +} + +func (pn *PeerNotifier) PeerWentOffline(peerID messages.PeerID) { + pn.listenersMutex.RLock() + defer pn.listenersMutex.RUnlock() + + for listener := range pn.listeners { + listener.peerWentOffline(peerID) + } +} + +func (pn *PeerNotifier) PeerCameOnline(peerID messages.PeerID) { + pn.listenersMutex.RLock() + defer pn.listenersMutex.RUnlock() + + for listener := range pn.listeners { + listener.peerComeOnline(peerID) + } +} diff --git a/relay/server/store.go b/relay/server/store/store.go similarity index 61% rename from relay/server/store.go rename to relay/server/store/store.go index 4288e62c5..c19fb416f 100644 --- a/relay/server/store.go +++ b/relay/server/store/store.go @@ -1,41 +1,48 @@ -package server +package store import ( "sync" + + "github.com/netbirdio/netbird/relay/messages" ) +type IPeer interface { + Close() + ID() messages.PeerID +} + // Store is a thread-safe store of peers // It is used to store the peers that are connected to the relay server type Store struct { - peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster + peers map[messages.PeerID]IPeer peersLock sync.RWMutex } // NewStore creates a new Store instance func NewStore() *Store { return &Store{ - peers: make(map[string]*Peer), + peers: make(map[messages.PeerID]IPeer), } } // AddPeer adds a peer to the store -func (s *Store) AddPeer(peer *Peer) { +func (s *Store) AddPeer(peer IPeer) { s.peersLock.Lock() defer s.peersLock.Unlock() - odlPeer, ok := s.peers[peer.String()] + odlPeer, ok := s.peers[peer.ID()] if ok { odlPeer.Close() } - s.peers[peer.String()] = peer + s.peers[peer.ID()] = peer } // DeletePeer deletes a peer from the store -func (s *Store) DeletePeer(peer *Peer) { +func (s *Store) DeletePeer(peer IPeer) { s.peersLock.Lock() defer s.peersLock.Unlock() - dp, ok := s.peers[peer.String()] + dp, ok := s.peers[peer.ID()] if !ok { return } @@ -43,11 +50,11 @@ func (s *Store) DeletePeer(peer *Peer) { return } - delete(s.peers, peer.String()) + delete(s.peers, peer.ID()) } // Peer returns a peer by its ID -func (s *Store) Peer(id string) (*Peer, bool) { +func (s *Store) Peer(id messages.PeerID) (IPeer, bool) { s.peersLock.RLock() defer s.peersLock.RUnlock() @@ -56,11 +63,11 @@ func (s *Store) Peer(id string) (*Peer, bool) { } // Peers returns all the peers in the store -func (s *Store) Peers() []*Peer { +func (s *Store) Peers() []IPeer { s.peersLock.RLock() defer s.peersLock.RUnlock() - peers := make([]*Peer, 0, len(s.peers)) + peers := make([]IPeer, 0, len(s.peers)) for _, p := range s.peers { peers = append(peers, p) } diff --git a/relay/server/store/store_test.go b/relay/server/store/store_test.go new file mode 100644 index 000000000..ad549a62c --- /dev/null +++ b/relay/server/store/store_test.go @@ -0,0 +1,49 @@ +package store + +import ( + "testing" + + "github.com/netbirdio/netbird/relay/messages" +) + +type MocPeer struct { + id messages.PeerID +} + +func (m *MocPeer) Close() { + +} + +func (m *MocPeer) ID() messages.PeerID { + return m.id +} + +func TestStore_DeletePeer(t *testing.T) { + s := NewStore() + + pID := messages.HashID("peer_one") + p := &MocPeer{id: pID} + s.AddPeer(p) + s.DeletePeer(p) + if _, ok := s.Peer(pID); ok { + t.Errorf("peer was not deleted") + } +} + +func TestStore_DeleteDeprecatedPeer(t *testing.T) { + s := NewStore() + + pID1 := messages.HashID("peer_one") + pID2 := messages.HashID("peer_one") + + p1 := &MocPeer{id: pID1} + p2 := &MocPeer{id: pID2} + + s.AddPeer(p1) + s.AddPeer(p2) + s.DeletePeer(p1) + + if _, ok := s.Peer(pID2); !ok { + t.Errorf("second peer was deleted") + } +} diff --git a/relay/server/store_test.go b/relay/server/store_test.go deleted file mode 100644 index 41c7baa92..000000000 --- a/relay/server/store_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package server - -import ( - "context" - "net" - "testing" - "time" - - "go.opentelemetry.io/otel" - - "github.com/netbirdio/netbird/relay/metrics" -) - -type mockConn struct { -} - -func (m mockConn) Read(b []byte) (n int, err error) { - //TODO implement me - panic("implement me") -} - -func (m mockConn) Write(b []byte) (n int, err error) { - //TODO implement me - panic("implement me") -} - -func (m mockConn) Close() error { - return nil -} - -func (m mockConn) LocalAddr() net.Addr { - //TODO implement me - panic("implement me") -} - -func (m mockConn) RemoteAddr() net.Addr { - //TODO implement me - panic("implement me") -} - -func (m mockConn) SetDeadline(t time.Time) error { - //TODO implement me - panic("implement me") -} - -func (m mockConn) SetReadDeadline(t time.Time) error { - //TODO implement me - panic("implement me") -} - -func (m mockConn) SetWriteDeadline(t time.Time) error { - //TODO implement me - panic("implement me") -} - -func TestStore_DeletePeer(t *testing.T) { - s := NewStore() - - m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - - p := NewPeer(m, []byte("peer_one"), nil, nil) - s.AddPeer(p) - s.DeletePeer(p) - if _, ok := s.Peer(p.String()); ok { - t.Errorf("peer was not deleted") - } -} - -func TestStore_DeleteDeprecatedPeer(t *testing.T) { - s := NewStore() - - m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - - conn := &mockConn{} - p1 := NewPeer(m, []byte("peer_id"), conn, nil) - p2 := NewPeer(m, []byte("peer_id"), conn, nil) - - s.AddPeer(p1) - s.AddPeer(p2) - s.DeletePeer(p1) - - if _, ok := s.Peer(p2.String()); !ok { - t.Errorf("second peer was deleted") - } -} diff --git a/relay/server/url.go b/relay/server/url.go new file mode 100644 index 000000000..9cbf44642 --- /dev/null +++ b/relay/server/url.go @@ -0,0 +1,33 @@ +package server + +import ( + "fmt" + "net/url" + "strings" +) + +// getInstanceURL checks if user supplied a URL scheme otherwise adds to the +// provided address according to TLS definition and parses the address before returning it +func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { + addr := exposedAddress + split := strings.Split(exposedAddress, "://") + switch { + case len(split) == 1 && tlsSupported: + addr = "rels://" + exposedAddress + case len(split) == 1 && !tlsSupported: + addr = "rel://" + exposedAddress + case len(split) > 2: + return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) + } + + parsedURL, err := url.ParseRequestURI(addr) + if err != nil { + return "", fmt.Errorf("invalid exposed address: %v", err) + } + + if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { + return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) + } + + return parsedURL.String(), nil +} diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index ec2aa488c..2e67ab803 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -12,7 +12,6 @@ import ( "github.com/pion/logging" "github.com/pion/turn/v3" - "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/relay/auth/allow" "github.com/netbirdio/netbird/relay/auth/hmac" @@ -22,7 +21,6 @@ import ( ) var ( - av = &allow.Auth{} hmacTokenStore = &hmac.TokenStore{} pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} dataSize = 1024 * 1024 * 10 @@ -70,8 +68,12 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { port := 35000 + peerPairs serverAddress := fmt.Sprintf("127.0.0.1:%d", port) serverConnURL := fmt.Sprintf("rel://%s", serverAddress) - - srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av) + serverCfg := server.Config{ + ExposedAddress: serverConnURL, + TLSSupport: false, + AuthValidator: &allow.Auth{}, + } + srv, err := server.NewServer(serverCfg) if err != nil { t.Fatalf("failed to create server: %s", err) } @@ -98,8 +100,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) - err := c.Connect() + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -108,8 +110,8 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) - err := c.Connect() + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } @@ -119,13 +121,13 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { connsSender := make([]net.Conn, 0, peerPairs) connsReceiver := make([]net.Conn, 0, peerPairs) for i := 0; i < len(clientsSender); i++ { - conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i)) if err != nil { t.Fatalf("failed to bind channel: %s", err) } connsSender = append(connsSender, conn) - conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + conn, err = clientsReceiver[i].OpenConn(ctx, "sender-"+fmt.Sprint(i)) if err != nil { t.Fatalf("failed to bind channel: %s", err) } diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go index 93d084387..9e22a80ea 100644 --- a/relay/testec2/relay.go +++ b/relay/testec2/relay.go @@ -70,8 +70,8 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { ctx := context.Background() clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) - if err := c.Connect(); err != nil { + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + if err := c.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } clientsSender[i] = c @@ -79,7 +79,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { connsSender := make([]net.Conn, 0, peerPairs) for i := 0; i < len(clientsSender); i++ { - conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + conn, err := clientsSender[i].OpenConn(ctx, "receiver-"+fmt.Sprint(i)) if err != nil { log.Fatalf("failed to bind channel: %s", err) } @@ -156,8 +156,8 @@ func runReader(conn net.Conn) time.Duration { func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) - err := c.Connect() + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect(context.Background()) if err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -166,7 +166,7 @@ func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { connsReceiver := make([]net.Conn, 0, peerPairs) for i := 0; i < len(clientsReceiver); i++ { - conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + conn, err := clientsReceiver[i].OpenConn(context.Background(), "sender-"+fmt.Sprint(i)) if err != nil { log.Fatalf("failed to bind channel: %s", err) }