mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-05 23:08:10 +02:00
chore: [Signal] synchronize peer registry
This commit is contained in:
parent
8acddfd510
commit
06b0c46a5d
@ -3,6 +3,7 @@ package peer
|
|||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/wiretrustee/wiretrustee/signal/proto"
|
"github.com/wiretrustee/wiretrustee/signal/proto"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Peer representation of a connected Peer
|
// Peer representation of a connected Peer
|
||||||
@ -25,32 +26,46 @@ func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
|
|||||||
// Registry registry that holds all currently connected Peers
|
// Registry registry that holds all currently connected Peers
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
// Peer.key -> Peer
|
// Peer.key -> Peer
|
||||||
Peers map[string]*Peer
|
Peers sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRegistry creates a new connected Peer registry
|
// NewRegistry creates a new connected Peer registry
|
||||||
func NewRegistry() *Registry {
|
func NewRegistry() *Registry {
|
||||||
return &Registry{
|
return &Registry{}
|
||||||
Peers: make(map[string]*Peer),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get gets a peer from the registry
|
||||||
|
func (registry *Registry) Get(peerId string) (*Peer, bool) {
|
||||||
|
if load, ok := registry.Peers.Load(peerId); ok {
|
||||||
|
return load.(*Peer), ok
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (registry *Registry) IsPeerRegistered(peerId string) bool {
|
||||||
|
if _, ok := registry.Peers.Load(peerId); ok {
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register registers peer in the registry
|
// Register registers peer in the registry
|
||||||
func (reg *Registry) Register(peer *Peer) {
|
func (registry *Registry) Register(peer *Peer) {
|
||||||
if _, exists := reg.Peers[peer.Id]; exists {
|
// can be that peer already exists but it is fine (e.g. reconnect)
|
||||||
log.Warnf("peer [%s] has been already registered", peer.Id)
|
// todo investigate what happens to the old peer (especially Peer.Stream) when we override it
|
||||||
} else {
|
registry.Peers.Store(peer.Id, peer)
|
||||||
log.Printf("registering new peer [%s]", peer.Id)
|
log.Printf("registered peer [%s]", peer.Id)
|
||||||
}
|
|
||||||
//replace Peer even if exists
|
|
||||||
//todo should we really replace?
|
|
||||||
reg.Peers[peer.Id] = peer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deregister deregister Peer from the Registry (usually once it disconnects)
|
// Deregister deregister Peer from the Registry (usually once it disconnects)
|
||||||
func (reg *Registry) Deregister(peer *Peer) {
|
func (registry *Registry) Deregister(peer *Peer) {
|
||||||
if _, ok := reg.Peers[peer.Id]; ok {
|
_, loaded := registry.Peers.LoadAndDelete(peer.Id)
|
||||||
delete(reg.Peers, peer.Id)
|
if loaded {
|
||||||
log.Printf("deregistered peer [%s]", peer.Id)
|
log.Printf("deregistered peer [%s]", peer.Id)
|
||||||
|
} else {
|
||||||
|
log.Warnf("attempted to remove non-existent peer [%s]", peer.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,20 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestRegistry_GetNonExistentPeer(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
|
||||||
|
peer, ok := r.Get("non_existent_peer")
|
||||||
|
|
||||||
|
if peer != nil {
|
||||||
|
t.Errorf("expected non_existent_peer not found in the registry")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
t.Errorf("expected non_existent_peer not found in the registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRegistry_Register(t *testing.T) {
|
func TestRegistry_Register(t *testing.T) {
|
||||||
r := NewRegistry()
|
r := NewRegistry()
|
||||||
peer1 := NewPeer("test_peer_1", nil)
|
peer1 := NewPeer("test_peer_1", nil)
|
||||||
@ -11,15 +25,11 @@ func TestRegistry_Register(t *testing.T) {
|
|||||||
r.Register(peer1)
|
r.Register(peer1)
|
||||||
r.Register(peer2)
|
r.Register(peer2)
|
||||||
|
|
||||||
if len(r.Peers) != 2 {
|
if _, ok := r.Get("test_peer_1"); !ok {
|
||||||
t.Errorf("expected 2 registered peers")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := r.Peers["test_peer_1"]; !ok {
|
|
||||||
t.Errorf("expected test_peer_1 not found in the registry")
|
t.Errorf("expected test_peer_1 not found in the registry")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := r.Peers["test_peer_2"]; !ok {
|
if _, ok := r.Get("test_peer_2"); !ok {
|
||||||
t.Errorf("expected test_peer_2 not found in the registry")
|
t.Errorf("expected test_peer_2 not found in the registry")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -33,15 +43,11 @@ func TestRegistry_Deregister(t *testing.T) {
|
|||||||
|
|
||||||
r.Deregister(peer1)
|
r.Deregister(peer1)
|
||||||
|
|
||||||
if len(r.Peers) != 1 {
|
if _, ok := r.Get("test_peer_1"); ok {
|
||||||
t.Errorf("expected 1 registered peers after deregistring")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := r.Peers["test_peer_1"]; ok {
|
|
||||||
t.Errorf("expected test_peer_1 to absent in the registry after deregistering")
|
t.Errorf("expected test_peer_1 to absent in the registry after deregistering")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := r.Peers["test_peer_2"]; !ok {
|
if _, ok := r.Get("test_peer_2"); !ok {
|
||||||
t.Errorf("expected test_peer_2 not found in the registry")
|
t.Errorf("expected test_peer_2 not found in the registry")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,11 +27,11 @@ func NewServer() *SignalExchangeServer {
|
|||||||
// Send forwards a message to the signal peer
|
// Send forwards a message to the signal peer
|
||||||
func (s *SignalExchangeServer) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *SignalExchangeServer) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
|
||||||
if _, found := s.registry.Peers[msg.Key]; !found {
|
if !s.registry.IsPeerRegistered(msg.Key) {
|
||||||
return nil, fmt.Errorf("unknown peer %s", msg.Key)
|
return nil, fmt.Errorf("unknown peer %s", msg.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
if dstPeer, found := s.registry.Peers[msg.RemoteKey]; found {
|
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
|
||||||
//forward the message to the target peer
|
//forward the message to the target peer
|
||||||
err := dstPeer.Stream.Send(msg)
|
err := dstPeer.Stream.Send(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -63,7 +63,7 @@ func (s *SignalExchangeServer) ConnectStream(stream proto.SignalExchange_Connect
|
|||||||
}
|
}
|
||||||
log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
|
log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
|
||||||
// lookup the target peer where the message is going to
|
// lookup the target peer where the message is going to
|
||||||
if dstPeer, found := s.registry.Peers[msg.RemoteKey]; found {
|
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
|
||||||
//forward the message to the target peer
|
//forward the message to the target peer
|
||||||
err := dstPeer.Stream.Send(msg)
|
err := dstPeer.Stream.Send(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user