mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
[management] Skip IdP cache warm-up on Redis if data exists (#3733)
* Add Redis cache check to skip warm-up on startup if cache is already populated * Refactor Redis test container setup for reusability
This commit is contained in:
parent
3fa915e271
commit
d8dc107bee
@ -17,6 +17,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
cacheStore "github.com/eko/gocache/lib/v4/store"
|
cacheStore "github.com/eko/gocache/lib/v4/store"
|
||||||
|
"github.com/eko/gocache/store/redis/v4"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
@ -237,7 +238,7 @@ func BuildManager(
|
|||||||
|
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
go func() {
|
go func() {
|
||||||
err := am.warmupIDPCache(ctx)
|
err := am.warmupIDPCache(ctx, cacheStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err)
|
log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err)
|
||||||
// todo retry?
|
// todo retry?
|
||||||
@ -494,7 +495,25 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
|||||||
return nil, status.Errorf(status.Internal, "error while creating new account")
|
return nil, status.Errorf(status.Internal, "error while creating new account")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
|
func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context, store cacheStore.StoreInterface) error {
|
||||||
|
cold, err := am.isCacheCold(ctx, store)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cold {
|
||||||
|
log.WithContext(ctx).Debug("cache already populated, skipping warm up")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if delayStr, ok := os.LookupEnv("NB_IDP_CACHE_WARMUP_DELAY"); ok {
|
||||||
|
delay, err := time.ParseDuration(delayStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid IDP warmup delay: %w", err)
|
||||||
|
}
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
|
||||||
userData, err := am.idpManager.GetAllAccounts(ctx)
|
userData, err := am.idpManager.GetAllAccounts(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -534,6 +553,32 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isCacheCold checks if the cache needs warming up.
|
||||||
|
func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheStore.StoreInterface) (bool, error) {
|
||||||
|
if store.GetType() != redis.RedisType {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, err := am.Store.GetAnyAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = store.Get(ctx, accountID)
|
||||||
|
if err == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if notFoundErr := new(cacheStore.NotFound); errors.As(err, ¬FoundErr) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, fmt.Errorf("failed to check cache: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||||
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
@ -14,30 +14,30 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
|
||||||
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/testutil"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3201,3 +3201,53 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, account.IsDomainPrimaryAccount)
|
assert.True(t, account.IsDomainPrimaryAccount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("memory cache", func(t *testing.T) {
|
||||||
|
t.Run("should always return true", func(t *testing.T) {
|
||||||
|
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cold, err := manager.isCacheCold(context.Background(), cacheStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, cold)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("redis cache", func(t *testing.T) {
|
||||||
|
cleanup, redisURL, err := testutil.CreateRedisTestContainer()
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
||||||
|
|
||||||
|
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("should return true when no account exists", func(t *testing.T) {
|
||||||
|
cold, err := manager.isCacheCold(context.Background(), cacheStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, cold)
|
||||||
|
})
|
||||||
|
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("should return true when account is not found in cache", func(t *testing.T) {
|
||||||
|
cold, err := manager.isCacheCold(context.Background(), cacheStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, cold)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should return false when account is found in cache", func(t *testing.T) {
|
||||||
|
err = cacheStore.Set(context.Background(), account.Id, &idp.UserData{ID: "v", Name: "vv"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cold, err := manager.isCacheCold(context.Background(), cacheStore)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, cold)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
17
management/server/cache/idp_test.go
vendored
17
management/server/cache/idp_test.go
vendored
@ -8,12 +8,11 @@ import (
|
|||||||
|
|
||||||
"github.com/eko/gocache/lib/v4/store"
|
"github.com/eko/gocache/lib/v4/store"
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"github.com/testcontainers/testcontainers-go"
|
|
||||||
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/cache"
|
"github.com/netbirdio/netbird/management/server/cache"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewIDPCacheManagers(t *testing.T) {
|
func TestNewIDPCacheManagers(t *testing.T) {
|
||||||
@ -27,21 +26,11 @@ func TestNewIDPCacheManagers(t *testing.T) {
|
|||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.redis {
|
if tc.redis {
|
||||||
ctx := context.Background()
|
cleanup, redisURL, err := testutil.CreateRedisTestContainer()
|
||||||
redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("couldn't start redis container: %s", err)
|
t.Fatalf("couldn't start redis container: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
t.Cleanup(cleanup)
|
||||||
if err := redisContainer.Terminate(ctx); err != nil {
|
|
||||||
t.Logf("failed to terminate container: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
redisURL, err := redisContainer.ConnectionString(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("couldn't get connection string: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
t.Setenv(cache.RedisStoreEnvVar, redisURL)
|
||||||
}
|
}
|
||||||
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)
|
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)
|
||||||
|
@ -800,6 +800,19 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
|||||||
return s.GetAccount(ctx, peer.AccountID)
|
return s.GetAccount(ctx, peer.AccountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
|
||||||
|
var account types.Account
|
||||||
|
result := s.db.WithContext(ctx).Select("id").Limit(1).Find(&account)
|
||||||
|
if result.Error != nil {
|
||||||
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
var accountID string
|
var accountID string
|
||||||
|
@ -3263,3 +3263,28 @@ func TestSqlStore_GetAccountMeta(t *testing.T) {
|
|||||||
require.Equal(t, "private", accountMeta.DomainCategory)
|
require.Equal(t, "private", accountMeta.DomainCategory)
|
||||||
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
|
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetAnyAccountID(t *testing.T) {
|
||||||
|
t.Run("should return account ID when accounts exist", func(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID, err := store.GetAnyAccountID(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should return error when no accounts exist", func(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID, err := store.GetAnyAccountID(context.Background())
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
assert.Empty(t, accountID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -55,6 +55,7 @@ type Store interface {
|
|||||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||||
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
|
||||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
|
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
|
||||||
|
GetAnyAccountID(ctx context.Context) (string, error)
|
||||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
||||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/testcontainers/testcontainers-go"
|
"github.com/testcontainers/testcontainers-go"
|
||||||
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
"github.com/testcontainers/testcontainers-go/modules/mysql"
|
||||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||||
|
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||||
"github.com/testcontainers/testcontainers-go/wait"
|
"github.com/testcontainers/testcontainers-go/wait"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,3 +85,28 @@ func CreatePostgresTestContainer() (func(), error) {
|
|||||||
|
|
||||||
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
|
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateRedisTestContainer creates a new Redis container for testing.
|
||||||
|
func CreateRedisTestContainer() (func(), string, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
redisContainer, err := testcontainersredis.RunContainer(ctx, testcontainers.WithImage("redis:7"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
|
||||||
|
defer cancelFunc()
|
||||||
|
if err = redisContainer.Terminate(timeoutCtx); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to stop redis container %s: %s", redisContainer.GetContainerID(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
redisURL, err := redisContainer.ConnectionString(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleanup, redisURL, nil
|
||||||
|
}
|
||||||
|
@ -14,3 +14,9 @@ func CreateMysqlTestContainer() (func(), error) {
|
|||||||
// Empty function for MySQL
|
// Empty function for MySQL
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateRedisTestContainer() (func(), string, error) {
|
||||||
|
return func() {
|
||||||
|
// Empty function for Redis
|
||||||
|
}, "", nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user