[management] Add buffering for getAccount requests during login (#2449)
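During login bursts, every peer of an account previously triggered its own full GetAccount read from the store. This change buffers concurrent getAccount requests per account ID: the first request opens a short collection window (300ms by default, tunable via the NB_GET_ACCOUNT_BUFFER_INTERVAL environment variable), and all requests that arrive within that window are served by a single store read whose result is fanned out to every waiter.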

pascal-fischer 2024-08-20 20:06:01 +02:00 committed by GitHub
parent 8c2d37d3fc
commit 3ed90728e6
4 changed files with 329 additions and 6 deletions

View File

@@ -161,6 +161,8 @@ type DefaultAccountManager struct {
 	eventStore activity.Store
 	geo        *geolocation.Geolocation

+	cache *AccountCache
+
 	// singleAccountMode indicates whether the instance has a single account.
 	// If true, then every new user will end up under the same account.
 	// This value will be set to false if management service has more than one account.
@@ -967,6 +969,7 @@ func BuildManager(
 		userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
 		integratedPeerValidator:  integratedPeerValidator,
 		metrics:                  metrics,
+		cache:                    NewAccountCache(ctx, store),
 	}

 	allAccounts := store.GetAllAccounts(ctx)
 	// enable single account mode only if configured by the user and the number of existing accounts is not greater than 1

View File

@@ -0,0 +1,106 @@
+package server
+
+import (
+	"context"
+	"os"
+	"sync"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// AccountRequest holds the result channel to return the requested account.
+type AccountRequest struct {
+	AccountID  string
+	ResultChan chan *AccountResult
+}
+
+// AccountResult holds the account data or an error.
+type AccountResult struct {
+	Account *Account
+	Err     error
+}
+
+type AccountCache struct {
+	store               Store
+	getAccountRequests  map[string][]*AccountRequest
+	mu                  sync.Mutex
+	getAccountRequestCh chan *AccountRequest
+	bufferInterval      time.Duration
+}
+
+func NewAccountCache(ctx context.Context, store Store) *AccountCache {
+	// default to 300ms; NB_GET_ACCOUNT_BUFFER_INTERVAL overrides it when
+	// set to a valid duration string
+	bufferInterval := 300 * time.Millisecond
+	if bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL"); bufferIntervalStr != "" {
+		parsed, err := time.ParseDuration(bufferIntervalStr)
+		if err != nil {
+			log.WithContext(ctx).Warnf("failed to parse account cache buffer interval: %s", err)
+		} else {
+			bufferInterval = parsed
+		}
+	}
+	log.WithContext(ctx).Infof("set account cache buffer interval to %s", bufferInterval)
+
+	ac := AccountCache{
+		store:               store,
+		getAccountRequests:  make(map[string][]*AccountRequest),
+		getAccountRequestCh: make(chan *AccountRequest),
+		bufferInterval:      bufferInterval,
+	}
+
+	go ac.processGetAccountRequests(ctx)
+
+	return &ac
+}
+
+// GetAccountWithBackpressure enqueues the request and blocks until the
+// batched store read for this account ID completes.
+func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
+	req := &AccountRequest{
+		AccountID:  accountID,
+		ResultChan: make(chan *AccountResult, 1),
+	}
+
+	log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
+	startTime := time.Now()
+	ac.getAccountRequestCh <- req
+	result := <-req.ResultChan
+	log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
+
+	return result.Account, result.Err
+}
+
+// processGetAccountBatch performs a single store read for the account and
+// fans the result out to every request buffered for that account ID.
+func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID string) {
+	ac.mu.Lock()
+	requests := ac.getAccountRequests[accountID]
+	delete(ac.getAccountRequests, accountID)
+	ac.mu.Unlock()
+
+	if len(requests) == 0 {
+		return
+	}
+
+	startTime := time.Now()
+	account, err := ac.store.GetAccount(ctx, accountID)
+	log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
+
+	result := &AccountResult{Account: account, Err: err}
+	for _, req := range requests {
+		req.ResultChan <- result
+		close(req.ResultChan)
+	}
+}
+
+// processGetAccountRequests buffers incoming requests per account ID; the
+// first request for an account schedules a batch read after the buffer
+// interval elapses.
+func (ac *AccountCache) processGetAccountRequests(ctx context.Context) {
+	for {
+		select {
+		case req := <-ac.getAccountRequestCh:
+			ac.mu.Lock()
+			ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
+			if len(ac.getAccountRequests[req.AccountID]) == 1 {
+				go func(ctx context.Context, accountID string) {
+					time.Sleep(ac.bufferInterval)
+					ac.processGetAccountBatch(ctx, accountID)
+				}(ctx, req.AccountID)
+			}
+			ac.mu.Unlock()
+		case <-ctx.Done():
+			return
+		}
+	}
+}
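
To illustrate the batching behavior, here is a minimal sketch (not part of the commit, imports elided, assumed to live in the same server package). countingStore is a hypothetical stub: embedding the Store interface satisfies it, and only GetAccount, the one method the cache calls, is overridden.

// countingStore counts how many reads actually reach the store.
type countingStore struct {
	Store
	hits int32
}

func (s *countingStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
	atomic.AddInt32(&s.hits, 1) // count real store reads
	return &Account{Id: accountID}, nil
}

func demoAccountCacheCoalescing(ctx context.Context) {
	store := &countingStore{}
	cache := NewAccountCache(ctx, store)

	// 100 concurrent logins for the same account ID...
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, _ = cache.GetAccountWithBackpressure(ctx, "account-1")
		}()
	}
	wg.Wait()

	// ...should report a single store read, since all requests arrive
	// within one buffer window (300ms by default).
	fmt.Printf("store reads: %d\n", atomic.LoadInt32(&store.hits))
}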

View File

@@ -3,13 +3,17 @@ package server
 import (
 	"context"
 	"fmt"
+	"io"
 	"net"
 	"os"
 	"path/filepath"
 	"runtime"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"

+	log "github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/require"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 	"google.golang.org/grpc"
@@ -24,6 +28,12 @@ import (
 	"github.com/netbirdio/netbird/util"
 )

+type TestingT interface {
+	require.TestingT
+	Helper()
+	Cleanup(func())
+}
+
 var (
 	kaep = keepalive.EnforcementPolicy{
 		MinTime: 15 * time.Second,
@@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) {
 	defer func() {
 		os.Remove(filepath.Join(dir, "store.json")) //nolint
 	}()
-	mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
+	mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
 		Stuns: []*Host{{
 			Proto: "udp",
 			URI:   "stun:stun.wiretrustee.com:3468",
@@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
 	}
 }

-func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
+func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
 	t.Helper()
 	lis, err := net.Listen("tcp", "localhost:0")
 	if err != nil {
@@ -485,7 +495,7 @@ func testSyncStatusRace(t *testing.T) {
 		os.Remove(filepath.Join(dir, "store.json")) //nolint
 	}()
-	mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
+	mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
 		Stuns: []*Host{{
 			Proto: "udp",
 			URI:   "stun:stun.wiretrustee.com:3468",
@@ -545,7 +555,6 @@ func testSyncStatusRace(t *testing.T) {
 	ctx2, cancelFunc2 := context.WithCancel(context.Background())

-	//client.
 	sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
 		WgPubKey: concurrentPeerKey2.PublicKey().String(),
 		Body:     message2,
@@ -626,3 +635,208 @@ func testSyncStatusRace(t *testing.T) {
 		t.Fatal("Peer should be connected")
 	}
 }
+func Test_LoginPerformance(t *testing.T) {
+	if os.Getenv("CI") == "true" {
+		t.Skip("Skipping on CI")
+	}
+
+	t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
+
+	benchCases := []struct {
+		name     string
+		peers    int
+		accounts int
+	}{
+		// {"XXS", 5, 1},
+		// {"XS", 10, 1},
+		// {"S", 100, 1},
+		// {"M", 250, 1},
+		// {"L", 500, 1},
+		// {"XL", 750, 1},
+		{"XXL", 1000, 5},
+	}
+
+	log.SetOutput(io.Discard)
+	defer log.SetOutput(os.Stderr)
+
+	for _, bc := range benchCases {
+		t.Run(bc.name, func(t *testing.T) {
+			t.Helper()
+			dir := t.TempDir()
+			err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer func() {
+				os.Remove(filepath.Join(dir, "store.json")) //nolint
+			}()
+
+			mgmtServer, am, _, err := startManagementForTest(t, &Config{
+				Stuns: []*Host{{
+					Proto: "udp",
+					URI:   "stun:stun.wiretrustee.com:3468",
+				}},
+				TURNConfig: &TURNConfig{
+					TimeBasedCredentials: false,
+					CredentialsTTL:       util.Duration{},
+					Secret:               "whatever",
+					Turns: []*Host{{
+						Proto: "udp",
+						URI:   "turn:stun.wiretrustee.com:3468",
+					}},
+				},
+				Signal: &Host{
+					Proto: "http",
+					URI:   "signal.wiretrustee.com:10000",
+				},
+				Datadir:    dir,
+				HttpConfig: nil,
+			})
+			if err != nil {
+				t.Fatal(err)
+				return
+			}
+			defer mgmtServer.GracefulStop()
+
+			var counter int32
+			var counterStart int32
+			var wg sync.WaitGroup
+			var mu sync.Mutex
+			messageCalls := []func() error{}
+
+			for j := 0; j < bc.accounts; j++ {
+				wg.Add(1)
+				go func(j int, counter *int32, counterStart *int32) {
+					defer wg.Done()
+
+					account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
+					if err != nil {
+						t.Logf("account creation failed: %v", err)
+						return
+					}
+
+					setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
+					if err != nil {
+						t.Logf("error creating setup key: %v", err)
+						return
+					}
+
+					for i := 0; i < bc.peers; i++ {
+						key, err := wgtypes.GeneratePrivateKey()
+						if err != nil {
+							t.Logf("failed to generate key: %v", err)
+							return
+						}
+
+						meta := &mgmtProto.PeerSystemMeta{
+							Hostname:           key.PublicKey().String(),
+							GoOS:               runtime.GOOS,
+							OS:                 runtime.GOOS,
+							Core:               "core",
+							Platform:           "platform",
+							Kernel:             "kernel",
+							WiretrusteeVersion: "",
+						}
+
+						peerLogin := PeerLogin{
+							WireGuardPubKey: key.String(),
+							SSHKey:          "random",
+							Meta:            extractPeerMeta(context.Background(), meta),
+							SetupKey:        setupKey.Key,
+							ConnectionIP:    net.IP{1, 1, 1, 1},
+						}
+
+						login := func() error {
+							// shadow err: these closures run concurrently later,
+							// so they must not share the goroutine's err variable
+							_, _, _, err := am.LoginPeer(context.Background(), peerLogin)
+							if err != nil {
+								t.Logf("failed to login peer: %v", err)
+								return err
+							}
+							// use the value returned by AddInt32 instead of
+							// re-reading the counter, keeping the log race-free
+							if count := atomic.AddInt32(counter, 1); count%100 == 0 {
+								t.Logf("finished %d login calls", count)
+							}
+							return nil
+						}
+
+						mu.Lock()
+						messageCalls = append(messageCalls, login)
+						mu.Unlock()
+
+						_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+						if err != nil {
+							t.Logf("failed to login peer: %v", err)
+							return
+						}
+
+						if count := atomic.AddInt32(counterStart, 1); count%100 == 0 {
+							t.Logf("registered %d peers", count)
+						}
+					}
+				}(j, &counter, &counterStart)
+			}
+			wg.Wait()
+
+			t.Logf("prepared %d login calls", len(messageCalls))
+			testLoginPerformance(t, messageCalls)
+		})
+	}
+}
+
+func testLoginPerformance(t *testing.T, loginCalls []func() error) {
+	t.Helper()
+	wgSetup := sync.WaitGroup{}
+	startChan := make(chan struct{})
+	wgDone := sync.WaitGroup{}
+	durations := []time.Duration{}
+	l := sync.Mutex{}
+
+	for i, function := range loginCalls {
+		wgSetup.Add(1)
+		wgDone.Add(1)
+		go func(function func() error, i int) {
+			defer wgDone.Done()
+			wgSetup.Done()
+			<-startChan
+
+			start := time.Now()
+			err := function()
+			if err != nil {
+				t.Logf("Error: %v", err)
+				return
+			}
+			duration := time.Since(start)
+
+			l.Lock()
+			durations = append(durations, duration)
+			l.Unlock()
+		}(function, i)
+	}
+
+	wgSetup.Wait()
+	t.Logf("starting login calls")
+	close(startChan)
+	wgDone.Wait()
+
+	// guard against a divide-by-zero when every login call failed
+	if len(durations) == 0 {
+		t.Fatal("no login calls succeeded")
+	}
+
+	var tMin, tMax, tSum time.Duration
+	for i, d := range durations {
+		if i == 0 {
+			tMin = d
+			tMax = d
+			tSum = d
+			continue
+		}
+		if d < tMin {
+			tMin = d
+		}
+		if d > tMax {
+			tMax = d
+		}
+		tSum += d
+	}
+	tAvg := tSum / time.Duration(len(durations))
+	t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg)
+}

View File

@@ -714,7 +714,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
 	unlockPeer()
 	unlockPeer = nil

-	account, err := am.Store.GetAccount(ctx, accountID)
+	account, err := am.cache.GetAccountWithBackpressure(ctx, accountID)
 	if err != nil {
 		return nil, nil, nil, err
 	}