package server

import (
	"context"
	"os"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/management/server/store"
	"github.com/netbirdio/netbird/management/server/types"
)

// AccountRequest holds the result channel to return the requested account.
type AccountRequest struct {
	AccountID  string
	ResultChan chan *AccountResult
}

// AccountResult holds the account data or an error.
type AccountResult struct {
	Account *types.Account
	Err     error
}

type AccountRequestBuffer struct {
	store               store.Store
	getAccountRequests  map[string][]*AccountRequest
	mu                  sync.Mutex
	getAccountRequestCh chan *AccountRequest
	bufferInterval      time.Duration
}

func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
	bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
	bufferInterval, err := time.ParseDuration(bufferIntervalStr)
	if err != nil {
		if bufferIntervalStr != "" {
			log.WithContext(ctx).Warnf("failed to parse account request buffer interval: %s", err)
		}
		bufferInterval = 100 * time.Millisecond
	}

	log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval)

	ac := AccountRequestBuffer{
		store:               store,
		getAccountRequests:  make(map[string][]*AccountRequest),
		getAccountRequestCh: make(chan *AccountRequest),
		bufferInterval:      bufferInterval,
	}

	go ac.processGetAccountRequests(ctx)

	return &ac
}
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
	req := &AccountRequest{
		AccountID:  accountID,
		ResultChan: make(chan *AccountResult, 1),
	}

	log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
	startTime := time.Now()
	ac.getAccountRequestCh <- req

	result := <-req.ResultChan
	log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
	return result.Account, result.Err
}

func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
	ac.mu.Lock()
	requests := ac.getAccountRequests[accountID]
	delete(ac.getAccountRequests, accountID)
	ac.mu.Unlock()

	if len(requests) == 0 {
		return
	}

	startTime := time.Now()
	account, err := ac.store.GetAccount(ctx, accountID)
	log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
	result := &AccountResult{Account: account, Err: err}

	for _, req := range requests {
		req.ResultChan <- result
		close(req.ResultChan)
	}
}

func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
	for {
		select {
		case req := <-ac.getAccountRequestCh:
			ac.mu.Lock()
			ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
			if len(ac.getAccountRequests[req.AccountID]) == 1 {
				go func(ctx context.Context, accountID string) {
					time.Sleep(ac.bufferInterval)
					ac.processGetAccountBatch(ctx, accountID)
				}(ctx, req.AccountID)
			}
			ac.mu.Unlock()
		case <-ctx.Done():
			return
		}
	}
}