mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
add session id to update channel
This commit is contained in:
parent
b952d8693d
commit
22126d0484
@ -1147,14 +1147,14 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
message := <-updMsg.channel
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
@ -1174,14 +1174,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
message := <-updMsg.channel
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
@ -1210,7 +1210,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
@ -1230,7 +1230,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
message := <-updMsg.channel
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
@ -1277,14 +1277,14 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
message := <-updMsg.channel
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
@ -1303,7 +1303,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
@ -1339,7 +1339,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
message := <-updMsg.channel
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
|
@ -499,14 +499,14 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
|
||||
t.Run("saving dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -526,7 +526,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -559,7 +559,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -585,7 +585,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -605,7 +605,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -625,7 +625,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing group with peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -418,14 +418,14 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Saving a group that is not linked to any resource should not update account peers
|
||||
t.Run("saving unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -448,7 +448,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -467,7 +467,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing peer from unliked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -485,7 +485,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -519,7 +519,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -541,7 +541,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -559,7 +559,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing peer from linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -588,7 +588,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -629,7 +629,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -194,31 +194,31 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerUpdates *PeerUpdateChannel, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
case update, open := <-updates:
|
||||
case update, open := <-peerUpdates.channel:
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
|
||||
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(peerUpdates.channel) + 1)
|
||||
}
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerUpdates.sessionID, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
log.WithContext(ctx).Debugf("stream of peer %s with session %s has been closed", peerKey.String(), peerUpdates.sessionID)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@ -226,10 +226,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, sessionID string, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, sessionID)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@ -237,18 +237,22 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer, sessionID)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, sessionID string) {
|
||||
|
||||
bool1 := s.peersUpdateManager.CloseChannel(ctx, peer.ID, sessionID)
|
||||
if bool1 {
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
|
||||
s.secretsManager.CancelRefresh(sessionID)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
|
@ -960,7 +960,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Creating a nameserver group with a distribution group no peers should not update account peers
|
||||
@ -968,7 +968,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -995,7 +995,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1013,7 +1013,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1039,7 +1039,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1069,7 +1069,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting nameserver group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -313,7 +313,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID, SessionIdForceOverwrite)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
|
@ -864,10 +864,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
peerChannels := make(map[string]*PeerUpdateChannel)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
peerChannels[peerID] = &PeerUpdateChannel{
|
||||
peerID: peerID,
|
||||
channel: make(chan *UpdateMessage, channelBufferSize),
|
||||
sessionID: xid.New().String(),
|
||||
}
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
@ -1315,14 +1319,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
|
||||
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1340,7 +1344,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1365,7 +1369,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1383,7 +1387,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating peer label", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1417,7 +1421,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1443,7 +1447,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1481,7 +1485,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1507,7 +1511,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1536,7 +1540,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1562,7 +1566,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with linked group to route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -856,7 +856,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
@ -878,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -913,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -948,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -982,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1016,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1051,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1085,7 +1085,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1105,7 +1105,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1126,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
policyID := "policy-destination-has-peers-source-none"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1145,7 +1145,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
policyID := "policy-rule-groups-no-peers"
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -147,7 +147,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
postureCheck := posture.Checks{
|
||||
@ -165,7 +165,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -183,7 +183,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -222,7 +222,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("linking posture check to policy with peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -251,7 +251,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -269,7 +269,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("removing posture check from policy", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -289,7 +289,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting unused posture check", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -328,7 +328,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -352,7 +352,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
|
||||
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID, updMsg1.sessionID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
@ -375,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg1)
|
||||
peerShouldReceiveUpdate(t, updMsg1.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -416,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -1807,7 +1807,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
|
||||
@ -1827,7 +1827,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1863,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1899,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating route with a routing peer", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1924,7 +1924,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1942,7 +1942,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting route", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1978,7 +1978,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -2018,7 +2018,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -408,7 +408,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
var setupKey *SetupKey
|
||||
@ -417,7 +417,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("creating setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -435,7 +435,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving setup key", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
@ -104,7 +104,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
select {
|
||||
case update := <-updateChannel:
|
||||
case update := <-updateChannel.channel:
|
||||
updates = append(updates, update)
|
||||
case <-timeout:
|
||||
break loop
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@ -12,15 +13,22 @@ import (
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
const SessionIdForceOverwrite = "FORCE"
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *NetworkMap
|
||||
}
|
||||
|
||||
type PeerUpdateChannel struct {
|
||||
peerID string
|
||||
sessionID string
|
||||
channel chan *UpdateMessage
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerChannels is a map of peerID to the channel used to deliver updates relevant to the peer
|
||||
peerChannels map[string]*PeerUpdateChannel
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
@ -30,7 +38,7 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerChannels: make(map[string]*PeerUpdateChannel),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
@ -50,14 +58,14 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
if peerUpdates, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
case peerUpdates.channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(peerUpdates.channel))
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
|
||||
@ -65,7 +73,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) *PeerUpdateChannel {
|
||||
start := time.Now()
|
||||
|
||||
closed := false
|
||||
@ -81,24 +89,39 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
close(channel.channel)
|
||||
log.WithContext(ctx).Debugf("overwriting existing channel for peer %s", peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
p.peerChannels[peerID] = channel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
|
||||
peerUpdateChannel := &PeerUpdateChannel{
|
||||
peerID: peerID,
|
||||
sessionID: uuid.New().String(),
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel: make(chan *UpdateMessage, channelBufferSize),
|
||||
}
|
||||
|
||||
return channel
|
||||
p.peerChannels[peerID] = peerUpdateChannel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s and session %s", peerID, peerUpdateChannel.sessionID)
|
||||
|
||||
return peerUpdateChannel
|
||||
}
|
||||
|
||||
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string, sessionID string) bool {
|
||||
if peerUpdates, ok := p.peerChannels[peerID]; ok {
|
||||
if peerUpdates.sessionID == sessionID || sessionID == SessionIdForceOverwrite {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(peerUpdates.channel)
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s and session %s", peerID, sessionID)
|
||||
return true
|
||||
}
|
||||
log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but current session is %s", peerID, sessionID, peerUpdates.sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but no channel found", peerID, sessionID)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CloseChannels closes updates channel for each given peer
|
||||
@ -114,12 +137,12 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string
|
||||
}()
|
||||
|
||||
for _, id := range peerIDs {
|
||||
p.closeChannel(ctx, id)
|
||||
p.closeChannel(ctx, id, SessionIdForceOverwrite)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseChannel closes updates channel of a given peer
|
||||
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
|
||||
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string, sessionID string) bool {
|
||||
start := time.Now()
|
||||
|
||||
p.channelsMux.Lock()
|
||||
@ -130,7 +153,7 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
|
||||
}
|
||||
}()
|
||||
|
||||
p.closeChannel(ctx, peerID)
|
||||
return p.closeChannel(ctx, peerID, sessionID)
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns a copy of the connected peers map
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@ -13,7 +15,7 @@ import (
|
||||
func TestCreateChannel(t *testing.T) {
|
||||
peer := "test-create"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
defer peersUpdater.CloseChannel(context.Background(), peer)
|
||||
defer peersUpdater.CloseChannel(context.Background(), peer, "sessionID")
|
||||
|
||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
@ -35,7 +37,7 @@ func TestSendUpdate(t *testing.T) {
|
||||
}
|
||||
peersUpdater.SendUpdate(context.Background(), peer, update1)
|
||||
select {
|
||||
case <-peersUpdater.peerChannels[peer]:
|
||||
case <-peersUpdater.peerChannels[peer].channel:
|
||||
default:
|
||||
t.Error("Update wasn't send")
|
||||
}
|
||||
@ -56,7 +58,7 @@ func TestSendUpdate(t *testing.T) {
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Error("timed out reading previously sent updates")
|
||||
case updateReader := <-peersUpdater.peerChannels[peer]:
|
||||
case updateReader := <-peersUpdater.peerChannels[peer].channel:
|
||||
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
|
||||
t.Error("got the update that shouldn't have been sent")
|
||||
}
|
||||
@ -65,15 +67,50 @@ func TestSendUpdate(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestCloseChannel(t *testing.T) {
|
||||
func TestCloseChannel_WithCorrectSessionID(t *testing.T) {
|
||||
peer := "test-close"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
_ = peersUpdater.CreateChannel(context.Background(), peer)
|
||||
peerUpdates := peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
peersUpdater.CloseChannel(context.Background(), peer)
|
||||
|
||||
updateDB := peersUpdater.CloseChannel(context.Background(), peer, peerUpdates.sessionID)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
|
||||
assert.Equal(t, true, updateDB)
|
||||
}
|
||||
|
||||
func TestCloseChannel_WithWrongSessionID(t *testing.T) {
|
||||
peer := "test-close"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
|
||||
updateDB := peersUpdater.CloseChannel(context.Background(), peer, "wrongSessionID")
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Should not close channel with wrong session id")
|
||||
}
|
||||
|
||||
assert.Equal(t, false, updateDB)
|
||||
}
|
||||
|
||||
func TestCloseChannel_WithForceOverwrite(t *testing.T) {
|
||||
peer := "test-close"
|
||||
peersUpdater := NewPeersUpdateManager(nil)
|
||||
peersUpdater.CreateChannel(context.Background(), peer)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; !ok {
|
||||
t.Error("Error creating the channel")
|
||||
}
|
||||
|
||||
updateDB := peersUpdater.CloseChannel(context.Background(), peer, SessionIdForceOverwrite)
|
||||
if _, ok := peersUpdater.peerChannels[peer]; ok {
|
||||
t.Error("Should close channel if forced")
|
||||
}
|
||||
|
||||
assert.Equal(t, true, updateDB)
|
||||
}
|
||||
|
@ -10,13 +10,14 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
@ -1297,14 +1298,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
|
||||
})
|
||||
|
||||
// Creating a new regular user should not update account peers and not send peer update
|
||||
t.Run("creating new regular user with no groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1327,7 +1328,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1350,7 +1351,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting user with no linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1387,7 +1388,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -1408,14 +1409,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
|
||||
t.Cleanup(func() {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID, peer4UpdMsg.sessionID)
|
||||
})
|
||||
|
||||
// deleting user with linked peers should update account peers and send peer update
|
||||
t.Run("deleting user with linked peers", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, peer4UpdMsg)
|
||||
peerShouldReceiveUpdate(t, peer4UpdMsg.channel)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user