mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 01:58:16 +02:00
[client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes. - The server manages and maintains these subscriptions. - Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
This commit is contained in:
@ -124,15 +124,14 @@ func (cc *connContainer) close() {
|
||||
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
parentCtx context.Context
|
||||
connectionURL string
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID []byte
|
||||
hashedID messages.PeerID
|
||||
|
||||
bufPool *sync.Pool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[string]*connContainer
|
||||
conns map[messages.PeerID]*connContainer
|
||||
serviceIsRunning bool
|
||||
mu sync.Mutex // protect serviceIsRunning and conns
|
||||
readLoopMutex sync.Mutex
|
||||
@ -142,14 +141,17 @@ type Client struct {
|
||||
|
||||
onDisconnectListener func(string)
|
||||
listenerMutex sync.Mutex
|
||||
|
||||
stateSubscription *PeersStateSubscription
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID := messages.HashID(peerID)
|
||||
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||
|
||||
c := &Client{
|
||||
log: log.WithFields(log.Fields{"relay": serverURL}),
|
||||
parentCtx: ctx,
|
||||
log: relayLog,
|
||||
connectionURL: serverURL,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
@ -159,14 +161,15 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
conns: make(map[string]*connContainer),
|
||||
conns: make(map[messages.PeerID]*connContainer),
|
||||
}
|
||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
|
||||
|
||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect() error {
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.log.Infof("connecting to relay server")
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
@ -178,17 +181,23 @@ func (c *Client) Connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
if err := c.connect(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
|
||||
|
||||
c.log = c.log.WithField("relay", c.instanceURL.String())
|
||||
c.log.Infof("relay connection established")
|
||||
|
||||
c.serviceIsRunning = true
|
||||
|
||||
internallyStoppedFlag := newInternalStopFlag()
|
||||
hc := healthcheck.NewReceiver(c.log)
|
||||
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
|
||||
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -196,26 +205,41 @@ func (c *Client) Connect() error {
|
||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||
// it will return immediately.
|
||||
// It block until the server confirm the peer is online.
|
||||
// todo: what should happen if call with the same peerID with multiple times?
|
||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
|
||||
peerID := messages.HashID(dstPeerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
|
||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||
_, ok := c.conns[hashedStringID]
|
||||
_, ok := c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.log.Infof("open connection to peer: %s", hashedStringID)
|
||||
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
||||
msgChannel := make(chan Msg, 100)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
|
||||
conn := NewConn(c, peerID, msgChannel, c.instanceURL)
|
||||
|
||||
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
|
||||
c.mu.Lock()
|
||||
_, ok = c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||
c.mu.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@ -254,7 +278,7 @@ func (c *Client) Close() error {
|
||||
return c.close(true)
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
func (c *Client) connect(ctx context.Context) error {
|
||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
@ -262,7 +286,7 @@ func (c *Client) connect() error {
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
if err = c.handShake(); err != nil {
|
||||
if err = c.handShake(ctx); err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
c.log.Errorf("failed to close connection: %s", cErr)
|
||||
@ -273,7 +297,7 @@ func (c *Client) connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
func (c *Client) handShake(ctx context.Context) error {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||
@ -286,7 +310,7 @@ func (c *Client) handShake() error {
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||
n, err := c.readWithTimeout(buf)
|
||||
n, err := c.readWithTimeout(ctx, buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to read auth response: %s", err)
|
||||
return err
|
||||
@ -319,11 +343,7 @@ func (c *Client) handShake() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
internallyStoppedFlag := newInternalStopFlag()
|
||||
hc := healthcheck.NewReceiver(c.log)
|
||||
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
|
||||
|
||||
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
|
||||
var (
|
||||
errExit error
|
||||
n int
|
||||
@ -370,6 +390,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
|
||||
c.instanceURL = nil
|
||||
c.muInstanceURL.Unlock()
|
||||
|
||||
c.stateSubscription.Cleanup()
|
||||
c.wgReadLoop.Done()
|
||||
_ = c.close(false)
|
||||
c.notifyDisconnected()
|
||||
@ -382,6 +403,14 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
||||
c.bufPool.Put(bufPtr)
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypePeersOnline:
|
||||
c.handlePeersOnlineMsg(buf)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
case messages.MsgTypePeersWentOffline:
|
||||
c.handlePeersWentOfflineMsg(buf)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
case messages.MsgTypeClose:
|
||||
c.log.Debugf("relay connection close by server")
|
||||
c.bufPool.Put(bufPtr)
|
||||
@ -413,18 +442,16 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
||||
return true
|
||||
}
|
||||
|
||||
stringID := messages.HashIDToString(peerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
container, ok := c.conns[stringID]
|
||||
container, ok := c.conns[*peerID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", stringID)
|
||||
c.log.Errorf("peer not found: %s", peerID.String())
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
@ -437,9 +464,9 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
|
||||
func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
conn, ok := c.conns[id]
|
||||
conn, ok := c.conns[dstID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
@ -464,7 +491,7 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-hc.OnTimeout:
|
||||
@ -478,7 +505,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
|
||||
c.log.Warnf("failed to close connection: %s", err)
|
||||
}
|
||||
return
|
||||
case <-c.parentCtx.Done():
|
||||
case <-ctx.Done():
|
||||
err := c.close(true)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to teardown connection: %s", err)
|
||||
@ -492,10 +519,31 @@ func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
c.conns = make(map[messages.PeerID]*connContainer)
|
||||
}
|
||||
|
||||
func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
container, ok := c.conns[peerID]
|
||||
if !ok {
|
||||
c.log.Warnf("can not close connection, peer not found: %s", peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
|
||||
container.close()
|
||||
delete(c.conns, peerID)
|
||||
}
|
||||
|
||||
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
|
||||
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@ -507,6 +555,11 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
|
||||
if container.conn != connReference {
|
||||
return fmt.Errorf("conn reference mismatch")
|
||||
}
|
||||
|
||||
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
|
||||
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
|
||||
}
|
||||
|
||||
c.log.Infof("free up connection to peer: %s", id)
|
||||
delete(c.conns, id)
|
||||
container.close()
|
||||
@ -559,8 +612,8 @@ func (c *Client) writeCloseMsg() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
|
||||
func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
|
||||
defer cancel()
|
||||
|
||||
readDone := make(chan struct{})
|
||||
@ -581,3 +634,21 @@ func (c *Client) readWithTimeout(buf []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handlePeersOnlineMsg(buf []byte) {
|
||||
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
|
||||
return
|
||||
}
|
||||
c.stateSubscription.OnPeersOnline(peersID)
|
||||
}
|
||||
|
||||
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||
peersID, err := messages.UnMarshalPeersWentOffline(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
|
||||
return
|
||||
}
|
||||
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||
}
|
||||
|
Reference in New Issue
Block a user