diff --git a/management/server/account.go b/management/server/account.go index cc5ca309a..ab1ffe8b3 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -17,6 +17,7 @@ import ( "time" cacheStore "github.com/eko/gocache/lib/v4/store" + "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/vmihailenco/msgpack/v5" @@ -237,7 +238,7 @@ func BuildManager( if !isNil(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx) + err := am.warmupIDPCache(ctx, cacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? @@ -494,7 +495,25 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain return nil, status.Errorf(status.Internal, "error while creating new account") } -func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { +func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cacheStore.StoreInterface) error { + cold, err := am.isCacheCold(ctx, store) + if err != nil { + return err + } + + if !cold { + log.WithContext(ctx).Debug("cache already populated, skipping warm up") + return nil + } + + if delayStr, ok := os.LookupEnv("NB_IDP_CACHE_WARMUP_DELAY"); ok { + delay, err := time.ParseDuration(delayStr) + if err != nil { + return fmt.Errorf("invalid IDP warmup delay: %w", err) + } + time.Sleep(delay) + } + userData, err := am.idpManager.GetAllAccounts(ctx) if err != nil { return err @@ -534,6 +553,32 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { return nil } +// isCacheCold checks if the cache needs warming up. +func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheStore.StoreInterface) (bool, error) { + if store.GetType() != redis.RedisType { + return true, nil + } + + accountID, err := am.Store.GetAnyAccountID(ctx) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return true, nil + } + return false, err + } + + _, err = store.Get(ctx, accountID) + if err == nil { + return false, nil + } + + if notFoundErr := new(cacheStore.NotFound); errors.As(err, ¬FoundErr) { + return true, nil + } + + return false, fmt.Errorf("failed to check cache: %w", err) +} + // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f34cf845..fe082d9a0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -14,30 +14,30 @@ import ( "time" "github.com/golang/mock/gomock" - - nbAccount "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/util" - - resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" - networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - + "github.com/netbirdio/netbird/management/server/idp" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -3201,3 +3201,53 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { assert.NoError(t, err) assert.True(t, account.IsDomainPrimaryAccount) } + +func TestDefaultAccountManager_IsCacheCold(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + t.Run("memory cache", func(t *testing.T) { + t.Run("should always return true", func(t *testing.T) { + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + require.NoError(t, err) + + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + }) + + t.Run("redis cache", func(t *testing.T) { + cleanup, redisURL, err := testutil.CreateRedisTestContainer() + require.NoError(t, err) + t.Cleanup(cleanup) + t.Setenv(cache.RedisStoreEnvVar, redisURL) + + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + require.NoError(t, err) + + t.Run("should return true when no account exists", func(t *testing.T) { + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + t.Run("should return true when account is not found in cache", func(t *testing.T) { + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.True(t, cold) + }) + + t.Run("should return false when account is found in cache", func(t *testing.T) { + err = cacheStore.Set(context.Background(), account.Id, &idp.UserData{ID: "v", Name: "vv"}) + require.NoError(t, err) + + cold, err := manager.isCacheCold(context.Background(), cacheStore) + assert.NoError(t, err) + assert.False(t, cold) + }) + }) +} diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index beefcd9bd..3fcfbb11a 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -8,12 +8,11 @@ import ( "github.com/eko/gocache/lib/v4/store" "github.com/redis/go-redis/v9" - "github.com/testcontainers/testcontainers-go" - testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/vmihailenco/msgpack/v5" "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/testutil" ) func TestNewIDPCacheManagers(t *testing.T) { @@ -27,21 +26,11 @@ func TestNewIDPCacheManagers(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { if tc.redis { - ctx := context.Background() - redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7")) + cleanup, redisURL, err := testutil.CreateRedisTestContainer() if err != nil { t.Fatalf("couldn't start redis container: %s", err) } - defer func() { - if err := redisContainer.Terminate(ctx); err != nil { - t.Logf("failed to terminate container: %s", err) - } - }() - redisURL, err := redisContainer.ConnectionString(ctx) - if err != nil { - t.Fatalf("couldn't get connection string: %s", err) - } - + t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index b73c372ae..7d3b288e0 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -800,6 +800,19 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( return s.GetAccount(ctx, peer.AccountID) } +func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { + var account types.Account + result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account) + if result.Error != nil { + return "", status.NewGetAccountFromStoreError(result.Error) + } + if result.RowsAffected == 0 { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return account.Id, nil +} + func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index c16a50108..8bd8ce098 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3263,3 +3263,28 @@ func TestSqlStore_GetAccountMeta(t *testing.T) { require.Equal(t, "private", accountMeta.DomainCategory) require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) } + +func TestSqlStore_GetAnyAccountID(t *testing.T) { + t.Run("should return account ID when accounts exist", func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID, err := store.GetAnyAccountID(context.Background()) + require.NoError(t, err) + assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID) + }) + + t.Run("should return error when no accounts exist", func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID, err := store.GetAnyAccountID(context.Background()) + require.Error(t, err) + sErr, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, sErr.Type(), status.NotFound) + assert.Empty(t, accountID) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 4a26bf5c3..ca332a493 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -55,6 +55,7 @@ type Store interface { GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) + GetAnyAccountID(ctx context.Context) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 8672efa7f..ca022bfef 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -12,6 +12,7 @@ import ( "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" + testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/testcontainers/testcontainers-go/wait" ) @@ -84,3 +85,28 @@ func CreatePostgresTestContainer() (func(), error) { return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) } + +// CreateRedisTestContainer creates a new Redis container for testing. +func CreateRedisTestContainer() (func(), string, error) { + ctx := context.Background() + + redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7")) + if err != nil { + return nil, "", err + } + + cleanup := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = redisContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop redis container %s: %s", redisContainer.GetContainerID(), err) + } + } + + redisURL, err := redisContainer.ConnectionString(ctx) + if err != nil { + return nil, "", err + } + + return cleanup, redisURL, nil +} diff --git a/management/server/testutil/store_ios.go b/management/server/testutil/store_ios.go index edde62f1e..a614258d2 100644 --- a/management/server/testutil/store_ios.go +++ b/management/server/testutil/store_ios.go @@ -14,3 +14,9 @@ func CreateMysqlTestContainer() (func(), error) { // Empty function for MySQL }, nil } + +func CreateRedisTestContainer() (func(), string, error) { + return func() { + // Empty function for Redis + }, "", nil +}