mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-23 14:28:51 +01:00
[management] Add buffering for getAccount requests during login (#2449)
This commit is contained in:
parent
8c2d37d3fc
commit
3ed90728e6
@ -161,6 +161,8 @@ type DefaultAccountManager struct {
|
|||||||
eventStore activity.Store
|
eventStore activity.Store
|
||||||
geo *geolocation.Geolocation
|
geo *geolocation.Geolocation
|
||||||
|
|
||||||
|
cache *AccountCache
|
||||||
|
|
||||||
// singleAccountMode indicates whether the instance has a single account.
|
// singleAccountMode indicates whether the instance has a single account.
|
||||||
// If true, then every new user will end up under the same account.
|
// If true, then every new user will end up under the same account.
|
||||||
// This value will be set to false if management service has more than one account.
|
// This value will be set to false if management service has more than one account.
|
||||||
@ -967,6 +969,7 @@ func BuildManager(
|
|||||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||||
integratedPeerValidator: integratedPeerValidator,
|
integratedPeerValidator: integratedPeerValidator,
|
||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
|
cache: NewAccountCache(ctx, store),
|
||||||
}
|
}
|
||||||
allAccounts := store.GetAllAccounts(ctx)
|
allAccounts := store.GetAllAccounts(ctx)
|
||||||
// enable single account mode only if configured by user and number of existing accounts is not greater than 1
|
||||||
|
106
management/server/account_cache.go
Normal file
106
management/server/account_cache.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccountRequest holds the result channel to return the requested account.
type AccountRequest struct {
	// AccountID identifies the account being requested.
	AccountID string
	// ResultChan receives exactly one AccountResult from the batch
	// processor, which then closes the channel.
	ResultChan chan *AccountResult
}
|
||||||
|
|
||||||
|
// AccountResult holds the account data or an error.
type AccountResult struct {
	// Account is the fetched account; nil when Err is set.
	Account *Account
	// Err is the error returned by the underlying store lookup, if any.
	Err error
}
|
||||||
|
|
||||||
|
// AccountCache buffers concurrent getAccount requests so that a burst of
// requests for the same account ID is served by a single store lookup.
type AccountCache struct {
	// store is the backing account store; queried once per batch.
	store Store
	// getAccountRequests collects pending requests per account ID until the
	// buffer interval elapses. Guarded by mu.
	getAccountRequests map[string][]*AccountRequest
	mu                 sync.Mutex
	// getAccountRequestCh funnels incoming requests into the single
	// processing goroutine started by NewAccountCache.
	getAccountRequestCh chan *AccountRequest
	// bufferInterval is how long requests for one account are accumulated
	// before the batched store lookup runs; configured via the
	// NB_GET_ACCOUNT_BUFFER_INTERVAL environment variable.
	bufferInterval time.Duration
}
|
||||||
|
|
||||||
|
func NewAccountCache(ctx context.Context, store Store) *AccountCache {
|
||||||
|
bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
|
||||||
|
bufferInterval, err := time.ParseDuration(bufferIntervalStr)
|
||||||
|
if err != nil && bufferIntervalStr != "" {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse account cache buffer interval: %s", err)
|
||||||
|
bufferInterval = 300 * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("set account cache buffer interval to %s", bufferInterval)
|
||||||
|
|
||||||
|
ac := AccountCache{
|
||||||
|
store: store,
|
||||||
|
getAccountRequests: make(map[string][]*AccountRequest),
|
||||||
|
getAccountRequestCh: make(chan *AccountRequest),
|
||||||
|
bufferInterval: bufferInterval,
|
||||||
|
}
|
||||||
|
|
||||||
|
go ac.processGetAccountRequests(ctx)
|
||||||
|
|
||||||
|
return &ac
|
||||||
|
}
|
||||||
|
func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
|
||||||
|
req := &AccountRequest{
|
||||||
|
AccountID: accountID,
|
||||||
|
ResultChan: make(chan *AccountResult, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
|
||||||
|
startTime := time.Now()
|
||||||
|
ac.getAccountRequestCh <- req
|
||||||
|
|
||||||
|
result := <-req.ResultChan
|
||||||
|
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
|
||||||
|
return result.Account, result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID string) {
|
||||||
|
ac.mu.Lock()
|
||||||
|
requests := ac.getAccountRequests[accountID]
|
||||||
|
delete(ac.getAccountRequests, accountID)
|
||||||
|
ac.mu.Unlock()
|
||||||
|
|
||||||
|
if len(requests) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
account, err := ac.store.GetAccount(ctx, accountID)
|
||||||
|
log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
|
||||||
|
result := &AccountResult{Account: account, Err: err}
|
||||||
|
|
||||||
|
for _, req := range requests {
|
||||||
|
req.ResultChan <- result
|
||||||
|
close(req.ResultChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *AccountCache) processGetAccountRequests(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case req := <-ac.getAccountRequestCh:
|
||||||
|
ac.mu.Lock()
|
||||||
|
ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
|
||||||
|
if len(ac.getAccountRequests[req.AccountID]) == 1 {
|
||||||
|
go func(ctx context.Context, accountID string) {
|
||||||
|
time.Sleep(ac.bufferInterval)
|
||||||
|
ac.processGetAccountBatch(ctx, accountID)
|
||||||
|
}(ctx, req.AccountID)
|
||||||
|
}
|
||||||
|
ac.mu.Unlock()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -3,13 +3,17 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@ -24,6 +28,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestingT is the subset of testing functionality needed by the helpers in
// this file (e.g. startManagementForTest), so they are not tied to *testing.T.
type TestingT interface {
	require.TestingT
	// Helper marks the calling function as a test helper.
	Helper()
	// Cleanup registers a function to run when the test completes.
	Cleanup(func())
}
||||||
|
|
||||||
var (
|
var (
|
||||||
kaep = keepalive.EnforcementPolicy{
|
kaep = keepalive.EnforcementPolicy{
|
||||||
MinTime: 15 * time.Second,
|
MinTime: 15 * time.Second,
|
||||||
@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||||
}()
|
}()
|
||||||
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
|
mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
|
||||||
Stuns: []*Host{{
|
Stuns: []*Host{{
|
||||||
Proto: "udp",
|
Proto: "udp",
|
||||||
URI: "stun:stun.wiretrustee.com:3468",
|
URI: "stun:stun.wiretrustee.com:3468",
|
||||||
@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -485,7 +495,7 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||||
}()
|
}()
|
||||||
|
|
||||||
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
|
mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
|
||||||
Stuns: []*Host{{
|
Stuns: []*Host{{
|
||||||
Proto: "udp",
|
Proto: "udp",
|
||||||
URI: "stun:stun.wiretrustee.com:3468",
|
URI: "stun:stun.wiretrustee.com:3468",
|
||||||
@ -545,7 +555,6 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
|
|
||||||
ctx2, cancelFunc2 := context.WithCancel(context.Background())
|
ctx2, cancelFunc2 := context.WithCancel(context.Background())
|
||||||
|
|
||||||
//client.
|
|
||||||
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
|
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
|
||||||
WgPubKey: concurrentPeerKey2.PublicKey().String(),
|
WgPubKey: concurrentPeerKey2.PublicKey().String(),
|
||||||
Body: message2,
|
Body: message2,
|
||||||
@ -626,3 +635,208 @@ func testSyncStatusRace(t *testing.T) {
|
|||||||
t.Fatal("Peer should be connected")
|
t.Fatal("Peer should be connected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_LoginPerformance(t *testing.T) {
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
t.Skip("Skipping on CI")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||||
|
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
accounts int
|
||||||
|
}{
|
||||||
|
// {"XXS", 5, 1},
|
||||||
|
// {"XS", 10, 1},
|
||||||
|
// {"S", 100, 1},
|
||||||
|
// {"M", 250, 1},
|
||||||
|
// {"L", 500, 1},
|
||||||
|
// {"XL", 750, 1},
|
||||||
|
{"XXL", 1000, 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
t.Run(bc.name, func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||||
|
}()
|
||||||
|
|
||||||
|
mgmtServer, am, _, err := startManagementForTest(t, &Config{
|
||||||
|
Stuns: []*Host{{
|
||||||
|
Proto: "udp",
|
||||||
|
URI: "stun:stun.wiretrustee.com:3468",
|
||||||
|
}},
|
||||||
|
TURNConfig: &TURNConfig{
|
||||||
|
TimeBasedCredentials: false,
|
||||||
|
CredentialsTTL: util.Duration{},
|
||||||
|
Secret: "whatever",
|
||||||
|
Turns: []*Host{{
|
||||||
|
Proto: "udp",
|
||||||
|
URI: "turn:stun.wiretrustee.com:3468",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
Signal: &Host{
|
||||||
|
Proto: "http",
|
||||||
|
URI: "signal.wiretrustee.com:10000",
|
||||||
|
},
|
||||||
|
Datadir: dir,
|
||||||
|
HttpConfig: nil,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer mgmtServer.GracefulStop()
|
||||||
|
|
||||||
|
var counter int32
|
||||||
|
var counterStart int32
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
messageCalls := []func() error{}
|
||||||
|
for j := 0; j < bc.accounts; j++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(j int, counter *int32, counterStart *int32) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("account creation failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error creating setup key: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < bc.peers; i++ {
|
||||||
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("failed to generate key: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := &mgmtProto.PeerSystemMeta{
|
||||||
|
Hostname: key.PublicKey().String(),
|
||||||
|
GoOS: runtime.GOOS,
|
||||||
|
OS: runtime.GOOS,
|
||||||
|
Core: "core",
|
||||||
|
Platform: "platform",
|
||||||
|
Kernel: "kernel",
|
||||||
|
WiretrusteeVersion: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
peerLogin := PeerLogin{
|
||||||
|
WireGuardPubKey: key.String(),
|
||||||
|
SSHKey: "random",
|
||||||
|
Meta: extractPeerMeta(context.Background(), meta),
|
||||||
|
SetupKey: setupKey.Key,
|
||||||
|
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
login := func() error {
|
||||||
|
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("failed to login peer: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
atomic.AddInt32(counter, 1)
|
||||||
|
if *counter%100 == 0 {
|
||||||
|
t.Logf("finished %d login calls", *counter)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
messageCalls = append(messageCalls, login)
|
||||||
|
mu.Unlock()
|
||||||
|
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("failed to login peer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.AddInt32(counterStart, 1)
|
||||||
|
if *counterStart%100 == 0 {
|
||||||
|
t.Logf("registered %d peers", *counterStart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(j, &counter, &counterStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
t.Logf("prepared %d login calls", len(messageCalls))
|
||||||
|
testLoginPerformance(t, messageCalls)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testLoginPerformance runs every loginCalls entry in its own goroutine,
// released simultaneously by a shared start signal, and logs the min/max/avg
// latency of the calls that succeeded. Failed calls are logged and excluded
// from the statistics.
func testLoginPerformance(t *testing.T, loginCalls []func() error) {
	t.Helper()
	wgSetup := sync.WaitGroup{}
	startChan := make(chan struct{}) // closed once all goroutines are ready

	wgDone := sync.WaitGroup{}
	durations := []time.Duration{}
	l := sync.Mutex{} // guards durations

	for _, function := range loginCalls {
		wgSetup.Add(1)
		wgDone.Add(1)
		go func(function func() error) {
			defer wgDone.Done()
			wgSetup.Done()

			// Wait until every goroutine has been spawned so the calls
			// start as close to simultaneously as possible.
			<-startChan
			start := time.Now()

			if err := function(); err != nil {
				t.Logf("Error: %v", err)
				return
			}

			duration := time.Since(start)
			l.Lock()
			durations = append(durations, duration)
			l.Unlock()
		}(function)
	}

	wgSetup.Wait()
	t.Logf("starting login calls")
	close(startChan)
	wgDone.Wait()

	// Guard against division by zero below when every call failed.
	if len(durations) == 0 {
		t.Logf("no successful login calls to report")
		return
	}

	var tMin, tMax, tSum time.Duration
	for i, d := range durations {
		if i == 0 {
			tMin, tMax, tSum = d, d, d
			continue
		}
		if d < tMin {
			tMin = d
		}
		if d > tMax {
			tMax = d
		}
		tSum += d
	}
	tAvg := tSum / time.Duration(len(durations))
	t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg)
}
||||||
|
@ -714,7 +714,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
unlockPeer()
|
unlockPeer()
|
||||||
unlockPeer = nil
|
unlockPeer = nil
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.cache.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user