mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 19:00:50 +01:00
112 lines
3.0 KiB
Go
112 lines
3.0 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/netbirdio/netbird/management/server/store"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
)
|
|
|
|
// AccountRequest is a single account lookup submitted to an AccountRequestBuffer.
// It names the account to load and carries the channel on which the batch
// processor delivers the result.
type AccountRequest struct {
	// AccountID identifies the account to fetch; pending requests are grouped
	// by this value.
	AccountID string
	// ResultChan receives one *AccountResult from the batch processor, which
	// then closes it. GetAccountWithBackpressure creates it with capacity 1.
	ResultChan chan *AccountResult
}
|
|
|
|
// AccountResult holds the outcome of a batched account lookup: the account and
// the error returned by the store's GetAccount call.
type AccountResult struct {
	// Account is the loaded account. processGetAccountBatch sends the same
	// *AccountResult (and therefore the same Account pointer) to every request
	// in a batch, so callers should not assume exclusive ownership.
	Account *types.Account
	// Err is the error reported by the store lookup, or nil.
	Err error
}
|
|
|
|
// AccountRequestBuffer coalesces concurrent account lookups. Requests for the
// same account that are queued before its batching window (bufferInterval,
// started by the first pending request) elapses are served by a single
// store.GetAccount call, and every requester receives that shared result.
type AccountRequestBuffer struct {
	// store is the backing store used to load accounts.
	store store.Store
	// getAccountRequests maps an account ID to its pending requests; guarded by mu.
	getAccountRequests map[string][]*AccountRequest
	// mu guards getAccountRequests, which is touched by both the dispatcher
	// goroutine and the per-account batch processors.
	mu sync.Mutex
	// getAccountRequestCh is how callers hand requests to the dispatcher
	// goroutine; NewAccountRequestBuffer creates it unbuffered.
	getAccountRequestCh chan *AccountRequest
	// bufferInterval is how long to wait after the first pending request for an
	// account before loading it.
	bufferInterval time.Duration
}
|
|
|
|
func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
|
|
bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
|
|
bufferInterval, err := time.ParseDuration(bufferIntervalStr)
|
|
if err != nil {
|
|
if bufferIntervalStr != "" {
|
|
log.WithContext(ctx).Warnf("failed to parse account request buffer interval: %s", err)
|
|
}
|
|
bufferInterval = 100 * time.Millisecond
|
|
}
|
|
|
|
log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval)
|
|
|
|
ac := AccountRequestBuffer{
|
|
store: store,
|
|
getAccountRequests: make(map[string][]*AccountRequest),
|
|
getAccountRequestCh: make(chan *AccountRequest),
|
|
bufferInterval: bufferInterval,
|
|
}
|
|
|
|
go ac.processGetAccountRequests(ctx)
|
|
|
|
return &ac
|
|
}
|
|
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
|
|
req := &AccountRequest{
|
|
AccountID: accountID,
|
|
ResultChan: make(chan *AccountResult, 1),
|
|
}
|
|
|
|
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
|
|
startTime := time.Now()
|
|
ac.getAccountRequestCh <- req
|
|
|
|
result := <-req.ResultChan
|
|
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
|
|
return result.Account, result.Err
|
|
}
|
|
|
|
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
|
|
ac.mu.Lock()
|
|
requests := ac.getAccountRequests[accountID]
|
|
delete(ac.getAccountRequests, accountID)
|
|
ac.mu.Unlock()
|
|
|
|
if len(requests) == 0 {
|
|
return
|
|
}
|
|
|
|
startTime := time.Now()
|
|
account, err := ac.store.GetAccount(ctx, accountID)
|
|
log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
|
|
result := &AccountResult{Account: account, Err: err}
|
|
|
|
for _, req := range requests {
|
|
req.ResultChan <- result
|
|
close(req.ResultChan)
|
|
}
|
|
}
|
|
|
|
func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case req := <-ac.getAccountRequestCh:
|
|
ac.mu.Lock()
|
|
ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
|
|
if len(ac.getAccountRequests[req.AccountID]) == 1 {
|
|
go func(ctx context.Context, accountID string) {
|
|
time.Sleep(ac.bufferInterval)
|
|
ac.processGetAccountBatch(ctx, accountID)
|
|
}(ctx, req.AccountID)
|
|
}
|
|
ac.mu.Unlock()
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|