diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f0..fb1f39868 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -18,7 +18,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -28,47 +27,6 @@ import ( "github.com/netbirdio/netbird/route" ) -type MocIntegratedValidator struct { - ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - if a.ValidatePeerFunc != nil { - return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) - } - return update, false, nil -} -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for _, peer := range peers { - validatedPeers[peer.ID] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) { -} - func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { t.Helper() peer := &nbpeer.Peer{ @@ -1038,7 +996,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } b.Run("public without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims) if err != nil { @@ -1048,7 +1006,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { }) b.Run("private without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -1059,7 +1017,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { b.Run("private with account ID", func(b *testing.B) { claims.AccountId = id - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { diff --git a/management/server/http/setupkeys_integration_test.go b/management/server/http/setupkeys_integration_test.go index e32ea225e..eb3fc1b01 100644 --- a/management/server/http/setupkeys_integration_test.go +++ b/management/server/http/setupkeys_integration_test.go @@ -25,6 +25,7 @@ import ( const ( testAccountId = "testUserId" testUserId = "testAccountId" + testPeerId = "testPeerId" newKeyName = "newKey" expiresIn = 3600 @@ -82,7 +83,14 @@ func Test_SetupKeys_Create_Success(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - peersUpdateManager := &server.PeersUpdateManager{} + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), testPeerId) + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + geoMock := &geolocation.GeolocationMock{} validatorMock := server.MocIntegratedValidator{} am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics) @@ -133,6 +141,12 @@ func Test_SetupKeys_Create_Success(t *testing.T) { } validateCreatedKey(t, tc.expectedSetupKey, toResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } }) } } @@ -162,3 +176,27 @@ func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupK assert.Equal(t, expectedKey, got) } + +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Error("Timed out waiting for update message") + } +} diff --git a/management/server/management_test.go b/management/server/management_test.go index 5361da53f..cc10449c8 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -21,10 +21,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" ) @@ -446,43 +443,6 @@ var _ = Describe("Management service", func() { }) }) -type MocIntegratedValidator struct { -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - return update, false, nil -} - -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for p := range peers { - validatedPeers[p] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) {} - func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -545,7 +505,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. log.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) if err != nil { log.Fatalf("failed creating a manager: %v", err) }