mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 11:20:18 +02:00
Add client part of the relay version changes
This commit is contained in:
@@ -1,13 +0,0 @@
|
||||
package client
|
||||
|
||||
type RelayAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a RelayAddr) Network() string {
|
||||
return "relay"
|
||||
}
|
||||
|
||||
func (a RelayAddr) String() string {
|
||||
return a.addr
|
||||
}
|
@@ -136,7 +136,7 @@ type Client struct {
|
||||
mu sync.Mutex // protect serviceIsRunning and conns
|
||||
readLoopMutex sync.Mutex
|
||||
wgReadLoop sync.WaitGroup
|
||||
instanceURL *RelayAddr
|
||||
instanceURL *messages.RelayAddr
|
||||
muInstanceURL sync.Mutex
|
||||
|
||||
onDisconnectListener func(string)
|
||||
@@ -189,7 +189,11 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
c.instanceURL = instanceURL
|
||||
c.muInstanceURL.Unlock()
|
||||
|
||||
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
|
||||
if c.instanceURL.FeatureVersionCode < messages.VersionSubscription {
|
||||
c.log.Warnf("server is deprecated, peer state subscription feature will not work")
|
||||
} else {
|
||||
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
|
||||
}
|
||||
|
||||
c.log = c.log.WithField("relay", instanceURL.String())
|
||||
c.log.Infof("relay connection established")
|
||||
@@ -291,7 +295,7 @@ func (c *Client) Close() error {
|
||||
return c.close(true)
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
func (c *Client) connect(ctx context.Context) (*messages.RelayAddr, error) {
|
||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
@@ -311,7 +315,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
return instanceURL, nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
func (c *Client) handShake(ctx context.Context) (*messages.RelayAddr, error) {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||
@@ -346,12 +350,16 @@ func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
return nil, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||
payload, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RelayAddr{addr: addr}, nil
|
||||
relayAddr, err := messages.UnmarshalRelayAddr(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return relayAddr, nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
|
||||
@@ -411,10 +419,16 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypePeersOnline:
|
||||
if c.stateSubscription == nil {
|
||||
c.log.Warnf("message type %s is not supported by the server, peer state subscription feature is not available)", msgType)
|
||||
}
|
||||
c.handlePeersOnlineMsg(buf)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
case messages.MsgTypePeersWentOffline:
|
||||
if c.stateSubscription == nil {
|
||||
c.log.Warnf("message type %s is not supported by the server, peer state subscription feature is not available)", msgType)
|
||||
}
|
||||
c.handlePeersWentOfflineMsg(buf)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
|
@@ -46,6 +46,10 @@ func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offl
|
||||
// OnPeersOnline should be called when a notification is received that certain peers have come online.
|
||||
// It checks if any of the peers are being waited on and signals their availability.
|
||||
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -63,6 +67,10 @@ func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
relevantPeers := make([]messages.PeerID, 0, len(peersID))
|
||||
for _, peerID := range peersID {
|
||||
@@ -79,6 +87,9 @@ func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
|
||||
|
||||
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
|
||||
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
// Check if already waiting for this peer
|
||||
s.mu.Lock()
|
||||
if _, exists := s.waitingPeers[peerID]; exists {
|
||||
@@ -132,6 +143,10 @@ func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context,
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
msgErr := s.unsubscribeStateChange(peerIDs)
|
||||
|
||||
s.mu.Lock()
|
||||
@@ -149,6 +164,10 @@ func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerI
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) Cleanup() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
@@ -25,7 +25,7 @@ func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{}) // discard log output
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
sub, _ := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil, 0)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@@ -45,7 +45,7 @@ func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{})
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
sub, _ := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
@@ -60,7 +60,7 @@ func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{})
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
sub, _ := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil, 0)
|
||||
|
||||
ctx := context.Background()
|
||||
go func() {
|
||||
@@ -78,7 +78,7 @@ func TestUnsubscribeStateChange(t *testing.T) {
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{})
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
sub, _ := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil, 0)
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
|
56
relay/messages/addr.go
Normal file
56
relay/messages/addr.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FeatureVersionCode uint16
|
||||
|
||||
const (
|
||||
VersionUnknown FeatureVersionCode = 0
|
||||
VersionSubscription FeatureVersionCode = 1
|
||||
)
|
||||
|
||||
type RelayAddr struct {
|
||||
Addr string `json:"ExposedAddr,omitempty"`
|
||||
FeatureVersionCode FeatureVersionCode `json:"Version,omitempty"`
|
||||
}
|
||||
|
||||
func (a RelayAddr) Network() string {
|
||||
return "relay"
|
||||
}
|
||||
|
||||
func (a RelayAddr) String() string {
|
||||
return a.Addr
|
||||
}
|
||||
|
||||
// UnmarshalRelayAddr json encoded RelayAddr data.
|
||||
func UnmarshalRelayAddr(data []byte) (*RelayAddr, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("unmarshalRelayAddr: empty data")
|
||||
}
|
||||
|
||||
var addr RelayAddr
|
||||
if err := json.Unmarshal(data, &addr); err != nil {
|
||||
addrString, err := fallbackToOldFormat(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fallback to old auth message: %v", err)
|
||||
}
|
||||
return &RelayAddr{Addr: addrString}, nil
|
||||
}
|
||||
|
||||
if addr.Addr == "" {
|
||||
return nil, fmt.Errorf("unmarshalRelayAddr: empty address in RelayAddr")
|
||||
}
|
||||
return &addr, nil
|
||||
}
|
||||
|
||||
func fallbackToOldFormat(data []byte) (string, error) {
|
||||
addr := string(data)
|
||||
if !strings.HasPrefix(addr, "rel://") && !strings.HasPrefix(addr, "rels://") {
|
||||
return "", fmt.Errorf("invalid address: must start with rel:// or rels://: %s", addr)
|
||||
}
|
||||
return addr, nil
|
||||
}
|
@@ -11,7 +11,7 @@ const (
|
||||
MaxHandshakeRespSize = 8192
|
||||
MaxMessageSize = 8820
|
||||
|
||||
CurrentProtocolVersion = 1
|
||||
CurrentProtocolVersion = 2
|
||||
|
||||
MsgTypeUnknown MsgType = 0
|
||||
// Deprecated: Use MsgTypeAuth instead.
|
||||
@@ -264,11 +264,11 @@ func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
func UnmarshalAuthResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < sizeOfProtoHeader+1 {
|
||||
return "", ErrInvalidMessageLength
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg[sizeOfProtoHeader:]), nil
|
||||
return msg[sizeOfProtoHeader:], nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
|
@@ -74,7 +74,7 @@ func TestMarshalAuthResponse(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if respAddr != address {
|
||||
if string(respAddr) != address {
|
||||
t.Errorf("expected %s, got %s", address, respAddr)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user