diff --git a/management/server/account_test.go b/management/server/account_test.go index e6c9b60da..34532937b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2374,3 +2374,24 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee return manager, account, peer1, peer2, peer3 } + +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(100 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + + select { + case <-updateMessage: + return + case <-time.After(100 * time.Millisecond): + t.Errorf("timed out waiting for update message") + } +} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index bf9a53d16..1e1b15edb 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "testing" + "time" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -824,3 +825,117 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { return 0 // a is equal to b } } + +func TestPolicyAccountPeerUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "group-id", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{}, + Destinations: []string{}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + // Saving policy with empty rule groups should not update account peers and not send peer update + t.Run("saving policy with empty rule groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // updating policy with rule groups should update account peers and send peer update + t.Run("updating policy with rule groups", func(t *testing.T) { + policy.Rules = []*PolicyRule{ + { + Enabled: true, + Sources: []string{"group-id"}, + Destinations: []string{"group-id"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Re-saving unchanged policy should trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("re-saving unchanged policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting policy should trigger account peers update and send peer update + t.Run("deleting policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) +}