[management] Skip IdP cache warm-up on Redis if data exists (#3733)

* Add Redis cache check to skip warm-up on startup if cache is already populated
* Refactor Redis test container setup for reusability
This commit is contained in:
Bethuel Mmbaga 2025-04-28 15:10:40 +03:00 committed by GitHub
parent 3fa915e271
commit d8dc107bee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 182 additions and 27 deletions

View File

@ -17,6 +17,7 @@ import (
"time" "time"
cacheStore "github.com/eko/gocache/lib/v4/store" cacheStore "github.com/eko/gocache/lib/v4/store"
"github.com/eko/gocache/store/redis/v4"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
@ -237,7 +238,7 @@ func BuildManager(
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
go func() { go func() {
err := am.warmupIDPCache(ctx) err := am.warmupIDPCache(ctx, cacheStore)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err)
// todo retry? // todo retry?
@ -494,7 +495,25 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
return nil, status.Errorf(status.Internal, "error while creating new account") return nil, status.Errorf(status.Internal, "error while creating new account")
} }
func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cacheStore.StoreInterface) error {
cold, err := am.isCacheCold(ctx, store)
if err != nil {
return err
}
if !cold {
log.WithContext(ctx).Debug("cache already populated, skipping warm up")
return nil
}
if delayStr, ok := os.LookupEnv("NB_IDP_CACHE_WARMUP_DELAY"); ok {
delay, err := time.ParseDuration(delayStr)
if err != nil {
return fmt.Errorf("invalid IDP warmup delay: %w", err)
}
time.Sleep(delay)
}
userData, err := am.idpManager.GetAllAccounts(ctx) userData, err := am.idpManager.GetAllAccounts(ctx)
if err != nil { if err != nil {
return err return err
@ -534,6 +553,32 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
return nil return nil
} }
// isCacheCold checks if the cache needs warming up.
func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheStore.StoreInterface) (bool, error) {
if store.GetType() != redis.RedisType {
return true, nil
}
accountID, err := am.Store.GetAnyAccountID(ctx)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return true, nil
}
return false, err
}
_, err = store.Get(ctx, accountID)
if err == nil {
return false, nil
}
if notFoundErr := new(cacheStore.NotFound); errors.As(err, &notFoundErr) {
return true, nil
}
return false, fmt.Errorf("failed to check cache: %w", err)
}
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)

View File

@ -14,30 +14,30 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/management/server/idp"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/util"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -3201,3 +3201,53 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount) assert.True(t, account.IsDomainPrimaryAccount)
} }
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
t.Run("memory cache", func(t *testing.T) {
t.Run("should always return true", func(t *testing.T) {
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
require.NoError(t, err)
cold, err := manager.isCacheCold(context.Background(), cacheStore)
assert.NoError(t, err)
assert.True(t, cold)
})
})
t.Run("redis cache", func(t *testing.T) {
cleanup, redisURL, err := testutil.CreateRedisTestContainer()
require.NoError(t, err)
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
require.NoError(t, err)
t.Run("should return true when no account exists", func(t *testing.T) {
cold, err := manager.isCacheCold(context.Background(), cacheStore)
assert.NoError(t, err)
assert.True(t, cold)
})
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
t.Run("should return true when account is not found in cache", func(t *testing.T) {
cold, err := manager.isCacheCold(context.Background(), cacheStore)
assert.NoError(t, err)
assert.True(t, cold)
})
t.Run("should return false when account is found in cache", func(t *testing.T) {
err = cacheStore.Set(context.Background(), account.Id, &idp.UserData{ID: "v", Name: "vv"})
require.NoError(t, err)
cold, err := manager.isCacheCold(context.Background(), cacheStore)
assert.NoError(t, err)
assert.False(t, cold)
})
})
}

View File

@ -8,12 +8,11 @@ import (
"github.com/eko/gocache/lib/v4/store" "github.com/eko/gocache/lib/v4/store"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"github.com/testcontainers/testcontainers-go"
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
"github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/testutil"
) )
func TestNewIDPCacheManagers(t *testing.T) { func TestNewIDPCacheManagers(t *testing.T) {
@ -27,21 +26,11 @@ func TestNewIDPCacheManagers(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.redis { if tc.redis {
ctx := context.Background() cleanup, redisURL, err := testutil.CreateRedisTestContainer()
redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7"))
if err != nil { if err != nil {
t.Fatalf("couldn't start redis container: %s", err) t.Fatalf("couldn't start redis container: %s", err)
} }
defer func() { t.Cleanup(cleanup)
if err := redisContainer.Terminate(ctx); err != nil {
t.Logf("failed to terminate container: %s", err)
}
}()
redisURL, err := redisContainer.ConnectionString(ctx)
if err != nil {
t.Fatalf("couldn't get connection string: %s", err)
}
t.Setenv(cache.RedisStoreEnvVar, redisURL) t.Setenv(cache.RedisStoreEnvVar, redisURL)
} }
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)

View File

@ -800,6 +800,19 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
return s.GetAccount(ctx, peer.AccountID) return s.GetAccount(ctx, peer.AccountID)
} }
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
var account types.Account
result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account)
if result.Error != nil {
return "", status.NewGetAccountFromStoreError(result.Error)
}
if result.RowsAffected == 0 {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return account.Id, nil
}
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
var accountID string var accountID string

View File

@ -3263,3 +3263,28 @@ func TestSqlStore_GetAccountMeta(t *testing.T) {
require.Equal(t, "private", accountMeta.DomainCategory) require.Equal(t, "private", accountMeta.DomainCategory)
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
} }
func TestSqlStore_GetAnyAccountID(t *testing.T) {
t.Run("should return account ID when accounts exist", func(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID, err := store.GetAnyAccountID(context.Background())
require.NoError(t, err)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID)
})
t.Run("should return error when no accounts exist", func(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID, err := store.GetAnyAccountID(context.Background())
require.Error(t, err)
sErr, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, sErr.Type(), status.NotFound)
assert.Empty(t, accountID)
})
}

View File

@ -55,6 +55,7 @@ type Store interface {
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
GetAnyAccountID(ctx context.Context) (string, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)

View File

@ -12,6 +12,7 @@ import (
"github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/mysql"
"github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/modules/postgres"
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
"github.com/testcontainers/testcontainers-go/wait" "github.com/testcontainers/testcontainers-go/wait"
) )
@ -84,3 +85,28 @@ func CreatePostgresTestContainer() (func(), error) {
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
} }
// CreateRedisTestContainer creates a new Redis container for testing.
func CreateRedisTestContainer() (func(), string, error) {
ctx := context.Background()
redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7"))
if err != nil {
return nil, "", err
}
cleanup := func() {
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = redisContainer.Terminate(timeoutCtx); err != nil {
log.WithContext(ctx).Warnf("failed to stop redis container %s: %s", redisContainer.GetContainerID(), err)
}
}
redisURL, err := redisContainer.ConnectionString(ctx)
if err != nil {
return nil, "", err
}
return cleanup, redisURL, nil
}

View File

@ -14,3 +14,9 @@ func CreateMysqlTestContainer() (func(), error) {
// Empty function for MySQL // Empty function for MySQL
}, nil }, nil
} }
func CreateRedisTestContainer() (func(), string, error) {
return func() {
// Empty function for Redis
}, "", nil
}