mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 18:00:49 +01:00
762a26dcea
This PR fixes a race condition that occurs when an agent connects to a Signal stream multiple times within a short amount of time, which is common on slow and unstable internet connections. Every time an agent establishes a new connection to Signal, Signal creates a Stream and writes an entry to the registry of connected peers, storing the stream. Every time an agent disconnects, Signal removes the stream from the registry. On an unstable connection, the agent may detect a broken connection and attempt to reconnect to Signal. Signal then overrides the stored stream with the new one, but it may detect the old broken connection only later, causing the peer to be deregistered. The peer is then missing from the registry while the client still thinks it is connected, and any messages are rejected.
113 lines
3.4 KiB
Go
113 lines
3.4 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/netbirdio/netbird/signal/peer"
|
|
"github.com/netbirdio/netbird/signal/proto"
|
|
log "github.com/sirupsen/logrus"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
"io"
|
|
)
|
|
|
|
// Server an instance of a Signal server
|
|
type Server struct {
|
|
registry *peer.Registry
|
|
proto.UnimplementedSignalExchangeServer
|
|
}
|
|
|
|
// NewServer creates a new Signal server
|
|
func NewServer() *Server {
|
|
return &Server{
|
|
registry: peer.NewRegistry(),
|
|
}
|
|
}
|
|
|
|
// Send forwards a message to the signal peer
|
|
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
|
|
|
if !s.registry.IsPeerRegistered(msg.Key) {
|
|
return nil, fmt.Errorf("peer %s is not registered", msg.Key)
|
|
}
|
|
|
|
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
|
|
//forward the message to the target peer
|
|
err := dstPeer.Stream.Send(msg)
|
|
if err != nil {
|
|
log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
|
|
//todo respond to the sender?
|
|
}
|
|
} else {
|
|
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
|
|
//todo respond to the sender?
|
|
}
|
|
return &proto.EncryptedMessage{}, nil
|
|
}
|
|
|
|
// ConnectStream connects to the exchange stream
|
|
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
|
|
|
|
p, err := s.connectPeer(stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer func() {
|
|
log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
|
|
s.registry.Deregister(p)
|
|
}()
|
|
|
|
//needed to confirm that the peer has been registered so that the client can proceed
|
|
header := metadata.Pairs(proto.HeaderRegistered, "1")
|
|
err = stream.SendHeader(header)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
|
|
|
|
for {
|
|
//read incoming messages
|
|
msg, err := stream.Recv()
|
|
if err == io.EOF {
|
|
break
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
|
|
// lookup the target peer where the message is going to
|
|
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
|
|
//forward the message to the target peer
|
|
err := dstPeer.Stream.Send(msg)
|
|
if err != nil {
|
|
log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err)
|
|
//todo respond to the sender?
|
|
}
|
|
} else {
|
|
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
|
|
//todo respond to the sender?
|
|
}
|
|
}
|
|
<-stream.Context().Done()
|
|
return stream.Context().Err()
|
|
}
|
|
|
|
// Handles initial Peer connection.
|
|
// Each connection must provide an Id header.
|
|
// At this moment the connecting Peer will be registered in the peer.Registry
|
|
func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
|
|
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
|
|
if id, found := meta[proto.HeaderId]; found {
|
|
p := peer.NewPeer(id[0], stream)
|
|
s.registry.Register(p)
|
|
return p, nil
|
|
} else {
|
|
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId)
|
|
}
|
|
} else {
|
|
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
|
|
}
|
|
}
|