mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-23 14:28:51 +01:00
[management] Add buffering for getAccount requests during login (#2449)
This commit is contained in:
parent
8c2d37d3fc
commit
3ed90728e6
@ -161,6 +161,8 @@ type DefaultAccountManager struct {
|
||||
eventStore activity.Store
|
||||
geo *geolocation.Geolocation
|
||||
|
||||
cache *AccountCache
|
||||
|
||||
// singleAccountMode indicates whether the instance has a single account.
|
||||
// If true, then every new user will end up under the same account.
|
||||
// This value will be set to false if management service has more than one account.
|
||||
@ -967,6 +969,7 @@ func BuildManager(
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
metrics: metrics,
|
||||
cache: NewAccountCache(ctx, store),
|
||||
}
|
||||
allAccounts := store.GetAllAccounts(ctx)
|
||||
// enable single account mode only if configured by user and number of existing accounts is not greater than 1
|
||||
|
106
management/server/account_cache.go
Normal file
106
management/server/account_cache.go
Normal file
@ -0,0 +1,106 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// AccountRequest holds the result channel to return the requested account.
type AccountRequest struct {
	// AccountID identifies the account the caller wants loaded.
	AccountID string
	// ResultChan receives exactly one AccountResult from the batch processor
	// and is then closed by it; it is created with a buffer of 1 so the
	// sender never blocks on a slow or departed receiver.
	ResultChan chan *AccountResult
}
|
||||
|
||||
// AccountResult holds the account data or an error.
type AccountResult struct {
	// Account is the value returned by the store lookup (expected to be nil
	// when Err is non-nil, per Go convention — the store's contract is not
	// visible here).
	Account *Account
	// Err is the error from the store lookup, if any.
	Err error
}
|
||||
|
||||
// AccountCache buffers concurrent getAccount requests per account ID so that
// a burst of logins for the same account results in a single store lookup per
// buffer window instead of one lookup per request.
type AccountCache struct {
	store Store
	// getAccountRequests groups pending requests by account ID; guarded by mu.
	getAccountRequests map[string][]*AccountRequest
	mu                 sync.Mutex
	// getAccountRequestCh funnels incoming requests to the dispatcher goroutine.
	getAccountRequestCh chan *AccountRequest
	// bufferInterval is how long the first request for an account waits for
	// further requests to join its batch before the store is queried.
	bufferInterval time.Duration
}
|
||||
|
||||
func NewAccountCache(ctx context.Context, store Store) *AccountCache {
|
||||
bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
|
||||
bufferInterval, err := time.ParseDuration(bufferIntervalStr)
|
||||
if err != nil && bufferIntervalStr != "" {
|
||||
log.WithContext(ctx).Warnf("failed to parse account cache buffer interval: %s", err)
|
||||
bufferInterval = 300 * time.Millisecond
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("set account cache buffer interval to %s", bufferInterval)
|
||||
|
||||
ac := AccountCache{
|
||||
store: store,
|
||||
getAccountRequests: make(map[string][]*AccountRequest),
|
||||
getAccountRequestCh: make(chan *AccountRequest),
|
||||
bufferInterval: bufferInterval,
|
||||
}
|
||||
|
||||
go ac.processGetAccountRequests(ctx)
|
||||
|
||||
return &ac
|
||||
}
|
||||
func (ac *AccountCache) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
|
||||
req := &AccountRequest{
|
||||
AccountID: accountID,
|
||||
ResultChan: make(chan *AccountResult, 1),
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
|
||||
startTime := time.Now()
|
||||
ac.getAccountRequestCh <- req
|
||||
|
||||
result := <-req.ResultChan
|
||||
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
|
||||
return result.Account, result.Err
|
||||
}
|
||||
|
||||
func (ac *AccountCache) processGetAccountBatch(ctx context.Context, accountID string) {
|
||||
ac.mu.Lock()
|
||||
requests := ac.getAccountRequests[accountID]
|
||||
delete(ac.getAccountRequests, accountID)
|
||||
ac.mu.Unlock()
|
||||
|
||||
if len(requests) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
account, err := ac.store.GetAccount(ctx, accountID)
|
||||
log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
|
||||
result := &AccountResult{Account: account, Err: err}
|
||||
|
||||
for _, req := range requests {
|
||||
req.ResultChan <- result
|
||||
close(req.ResultChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (ac *AccountCache) processGetAccountRequests(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case req := <-ac.getAccountRequestCh:
|
||||
ac.mu.Lock()
|
||||
ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
|
||||
if len(ac.getAccountRequests[req.AccountID]) == 1 {
|
||||
go func(ctx context.Context, accountID string) {
|
||||
time.Sleep(ac.bufferInterval)
|
||||
ac.processGetAccountBatch(ctx, accountID)
|
||||
}(ctx, req.AccountID)
|
||||
}
|
||||
ac.mu.Unlock()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
@ -3,13 +3,17 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
@ -24,6 +28,12 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type TestingT interface {
|
||||
require.TestingT
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
|
||||
mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
||||
func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
@ -485,7 +495,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
|
||||
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
|
||||
mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@ -545,7 +555,6 @@ func testSyncStatusRace(t *testing.T) {
|
||||
|
||||
ctx2, cancelFunc2 := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: concurrentPeerKey2.PublicKey().String(),
|
||||
Body: message2,
|
||||
@ -574,7 +583,7 @@ func testSyncStatusRace(t *testing.T) {
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
// client.
|
||||
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
@ -626,3 +635,208 @@ func testSyncStatusRace(t *testing.T) {
|
||||
t.Fatal("Peer should be connected")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_LoginPerformance(t *testing.T) {
|
||||
if os.Getenv("CI") == "true" {
|
||||
t.Skip("Skipping on CI")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
|
||||
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
accounts int
|
||||
}{
|
||||
// {"XXS", 5, 1},
|
||||
// {"XS", 10, 1},
|
||||
// {"S", 100, 1},
|
||||
// {"M", 250, 1},
|
||||
// {"L", 500, 1},
|
||||
// {"XL", 750, 1},
|
||||
{"XXL", 1000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
for _, bc := range benchCases {
|
||||
t.Run(bc.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
|
||||
mgmtServer, am, _, err := startManagementForTest(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Proto: "http",
|
||||
URI: "signal.wiretrustee.com:10000",
|
||||
},
|
||||
Datadir: dir,
|
||||
HttpConfig: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
var counter int32
|
||||
var counterStart int32
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
messageCalls := []func() error{}
|
||||
for j := 0; j < bc.accounts; j++ {
|
||||
wg.Add(1)
|
||||
go func(j int, counter *int32, counterStart *int32) {
|
||||
defer wg.Done()
|
||||
|
||||
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
|
||||
if err != nil {
|
||||
t.Logf("account creation failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
|
||||
if err != nil {
|
||||
t.Logf("error creating setup key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < bc.peers; i++ {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Logf("failed to generate key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
meta := &mgmtProto.PeerSystemMeta{
|
||||
Hostname: key.PublicKey().String(),
|
||||
GoOS: runtime.GOOS,
|
||||
OS: runtime.GOOS,
|
||||
Core: "core",
|
||||
Platform: "platform",
|
||||
Kernel: "kernel",
|
||||
WiretrusteeVersion: "",
|
||||
}
|
||||
|
||||
peerLogin := PeerLogin{
|
||||
WireGuardPubKey: key.String(),
|
||||
SSHKey: "random",
|
||||
Meta: extractPeerMeta(context.Background(), meta),
|
||||
SetupKey: setupKey.Key,
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
}
|
||||
|
||||
login := func() error {
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return err
|
||||
}
|
||||
atomic.AddInt32(counter, 1)
|
||||
if *counter%100 == 0 {
|
||||
t.Logf("finished %d login calls", *counter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
messageCalls = append(messageCalls, login)
|
||||
mu.Unlock()
|
||||
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
|
||||
if err != nil {
|
||||
t.Logf("failed to login peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt32(counterStart, 1)
|
||||
if *counterStart%100 == 0 {
|
||||
t.Logf("registered %d peers", *counterStart)
|
||||
}
|
||||
}
|
||||
}(j, &counter, &counterStart)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
t.Logf("prepared %d login calls", len(messageCalls))
|
||||
testLoginPerformance(t, messageCalls)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testLoginPerformance(t *testing.T, loginCalls []func() error) {
|
||||
t.Helper()
|
||||
wgSetup := sync.WaitGroup{}
|
||||
startChan := make(chan struct{})
|
||||
|
||||
wgDone := sync.WaitGroup{}
|
||||
durations := []time.Duration{}
|
||||
l := sync.Mutex{}
|
||||
|
||||
for i, function := range loginCalls {
|
||||
wgSetup.Add(1)
|
||||
wgDone.Add(1)
|
||||
go func(function func() error, i int) {
|
||||
defer wgDone.Done()
|
||||
wgSetup.Done()
|
||||
|
||||
<-startChan
|
||||
start := time.Now()
|
||||
|
||||
err := function()
|
||||
if err != nil {
|
||||
t.Logf("Error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
l.Lock()
|
||||
durations = append(durations, duration)
|
||||
l.Unlock()
|
||||
}(function, i)
|
||||
}
|
||||
|
||||
wgSetup.Wait()
|
||||
t.Logf("starting login calls")
|
||||
close(startChan)
|
||||
wgDone.Wait()
|
||||
var tMin, tMax, tSum time.Duration
|
||||
for i, d := range durations {
|
||||
if i == 0 {
|
||||
tMin = d
|
||||
tMax = d
|
||||
tSum = d
|
||||
continue
|
||||
}
|
||||
if d < tMin {
|
||||
tMin = d
|
||||
}
|
||||
if d > tMax {
|
||||
tMax = d
|
||||
}
|
||||
tSum += d
|
||||
}
|
||||
tAvg := tSum / time.Duration(len(durations))
|
||||
t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg)
|
||||
}
|
||||
|
@ -714,7 +714,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
unlockPeer()
|
||||
unlockPeer = nil
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
account, err := am.cache.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user