mirror of
https://github.com/openziti/zrok.git
synced 2025-02-02 03:20:26 +01:00
massive bandwidth limits rewrite to support limit classes (#606)
This commit is contained in:
parent
bb63921e42
commit
97dbd197d6
@ -217,87 +217,82 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
shr, err := a.str.FindShareWithToken(u.ShareToken, trx)
|
||||
shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logrus.Infof("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode)
|
||||
logrus.Debugf("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode)
|
||||
|
||||
if enforce, warning, rxBytes, txBytes, err := a.checkBandwidthLimit(u.AccountId); err == nil {
|
||||
if enforce {
|
||||
enforced := false
|
||||
var enforcedAt time.Time
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty {
|
||||
if latest, err := a.str.FindLatestBandwidthLimitJournal(int(u.AccountId), trx); err == nil {
|
||||
enforced = latest.Action == store.LimitLimitAction
|
||||
alcs, err := a.str.FindAppliedLimitClassesForAccount(int(u.AccountId), trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, alcs)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error checking limit classes")
|
||||
}
|
||||
|
||||
if exceededLc != nil {
|
||||
enforced := false
|
||||
var enforcedAt time.Time
|
||||
|
||||
if exceededLc.IsGlobal() {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty {
|
||||
if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil {
|
||||
enforced = latest.Action == exceededLc.GetLimitAction()
|
||||
enforcedAt = latest.UpdatedAt
|
||||
}
|
||||
}
|
||||
|
||||
if !enforced {
|
||||
_, err := a.str.CreateBandwidthLimitJournalEntry(&store.BandwidthLimitJournalEntry{
|
||||
AccountId: int(u.AccountId),
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
Action: store.LimitLimitAction,
|
||||
}, trx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty {
|
||||
if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil {
|
||||
enforced = latest.Action == exceededLc.GetLimitAction()
|
||||
enforcedAt = latest.UpdatedAt
|
||||
}
|
||||
acct, err := a.str.GetAccount(int(u.AccountId), trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, action := range a.limitActions {
|
||||
if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
|
||||
}
|
||||
}
|
||||
if err := trx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("already enforced limit for account '#%d' at %v", u.AccountId, enforcedAt)
|
||||
}
|
||||
|
||||
} else if warning {
|
||||
warned := false
|
||||
var warnedAt time.Time
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty {
|
||||
if latest, err := a.str.FindLatestBandwidthLimitJournal(int(u.AccountId), trx); err == nil {
|
||||
warned = latest.Action == store.WarningLimitAction || latest.Action == store.LimitLimitAction
|
||||
warnedAt = latest.UpdatedAt
|
||||
}
|
||||
}
|
||||
|
||||
if !warned {
|
||||
_, err := a.str.CreateBandwidthLimitJournalEntry(&store.BandwidthLimitJournalEntry{
|
||||
AccountId: int(u.AccountId),
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
Action: store.WarningLimitAction,
|
||||
}, trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acct, err := a.str.GetAccount(int(u.AccountId), trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, action := range a.warningActions {
|
||||
if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
|
||||
}
|
||||
}
|
||||
if err := trx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("already warned account '#%d' at %v", u.AccountId, warnedAt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.Error(err)
|
||||
|
||||
if !enforced {
|
||||
je := &store.BandwidthLimitJournalEntry{
|
||||
AccountId: int(u.AccountId),
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
Action: exceededLc.GetLimitAction(),
|
||||
}
|
||||
if !exceededLc.IsGlobal() {
|
||||
lcId := exceededLc.GetLimitClassId()
|
||||
je.LimitClassId = &lcId
|
||||
}
|
||||
_, err := a.str.CreateBandwidthLimitJournalEntry(je, trx)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acct, err := a.str.GetAccount(int(u.AccountId), trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch exceededLc.GetLimitAction() {
|
||||
case store.LimitLimitAction:
|
||||
for _, limitAction := range a.limitActions {
|
||||
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
|
||||
}
|
||||
}
|
||||
|
||||
case store.WarningLimitAction:
|
||||
for _, warningAction := range a.warningActions {
|
||||
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := trx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -314,36 +309,77 @@ func (a *Agent) relax() error {
|
||||
|
||||
commit := false
|
||||
|
||||
if aljs, err := a.str.FindAllLatestBandwidthLimitJournal(trx); err == nil {
|
||||
for _, alj := range aljs {
|
||||
if acct, err := a.str.GetAccount(alj.AccountId, trx); err == nil {
|
||||
if alj.Action == store.WarningLimitAction || alj.Action == store.LimitLimitAction {
|
||||
if enforce, warning, rxBytes, txBytes, err := a.checkBandwidthLimit(int64(alj.AccountId)); err == nil {
|
||||
if !enforce && !warning {
|
||||
if alj.Action == store.LimitLimitAction {
|
||||
// run relax actions for account
|
||||
for _, action := range a.relaxActions {
|
||||
if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.Infof("relaxing warning for '%v'", acct.Email)
|
||||
}
|
||||
if err := a.str.DeleteBandwidthLimitJournal(acct.Id, trx); err == nil {
|
||||
commit = true
|
||||
} else {
|
||||
logrus.Errorf("error deleting account_limit_journal for '%v': %v", acct.Email, err)
|
||||
}
|
||||
} else {
|
||||
logrus.Infof("account '%v' still over limit", acct.Email)
|
||||
}
|
||||
} else {
|
||||
logrus.Errorf("error checking account limit for '%v': %v", acct.Email, err)
|
||||
}
|
||||
if bwjes, err := a.str.FindAllBandwidthLimitJournal(trx); err == nil {
|
||||
periodBw := make(map[int]struct {
|
||||
rx int64
|
||||
tx int64
|
||||
})
|
||||
|
||||
accounts := make(map[int]*store.Account)
|
||||
|
||||
for _, bwje := range bwjes {
|
||||
if _, found := accounts[bwje.AccountId]; !found {
|
||||
if acct, err := a.str.GetAccount(bwje.AccountId, trx); err == nil {
|
||||
accounts[bwje.AccountId] = acct
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var bwc store.BandwidthClass
|
||||
if bwje.LimitClassId != nil {
|
||||
globalBwcs := newConfigBandwidthClasses(a.cfg.Bandwidth)
|
||||
if bwje.Action == store.WarningLimitAction {
|
||||
bwc = globalBwcs[0]
|
||||
} else {
|
||||
bwc = globalBwcs[1]
|
||||
}
|
||||
} else {
|
||||
logrus.Errorf("error getting account for '#%d': %v", alj.AccountId, err)
|
||||
lc, err := a.str.GetLimitClass(*bwje.LimitClassId, trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bwc = lc
|
||||
}
|
||||
|
||||
if _, found := periodBw[bwc.GetPeriodMinutes()]; !found {
|
||||
rx, tx, err := a.ifx.totalRxTxForAccount(int64(bwje.AccountId), time.Duration(bwc.GetPeriodMinutes())*time.Minute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
periodBw[bwc.GetPeriodMinutes()] = struct {
|
||||
rx int64
|
||||
tx int64
|
||||
}{
|
||||
rx: rx,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
used := periodBw[bwc.GetPeriodMinutes()]
|
||||
if !a.limitExceeded(used.rx, used.tx, bwc) {
|
||||
if bwc.GetLimitAction() == store.LimitLimitAction {
|
||||
for _, action := range a.relaxActions {
|
||||
if err := action.HandleAccount(accounts[bwje.AccountId], used.rx, used.tx, bwc, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.Infof("relaxing warning for '%v'", accounts[bwje.AccountId].Email)
|
||||
}
|
||||
var lcId *int
|
||||
if !bwc.IsGlobal() {
|
||||
newLcId := 0
|
||||
newLcId = bwc.GetLimitClassId()
|
||||
lcId = &newLcId
|
||||
}
|
||||
if err := a.str.DeleteBandwidthLimitJournalEntryForLimitClass(bwje.AccountId, lcId, trx); err == nil {
|
||||
commit = true
|
||||
} else {
|
||||
logrus.Errorf("error deleting bandwidth limit journal entry for '%v': %v", accounts[bwje.AccountId].Email, err)
|
||||
}
|
||||
} else {
|
||||
logrus.Infof("account '%v' still over limit: %v", accounts[bwje.AccountId].Email, bwc)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -359,44 +395,87 @@ func (a *Agent) relax() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) checkBandwidthLimit(acctId int64) (enforce, warning bool, rxBytes, txBytes int64, err error) {
|
||||
period := 24 * time.Hour
|
||||
limit := DefaultBandwidthPerPeriod()
|
||||
if a.cfg.Bandwidth != nil {
|
||||
limit = a.cfg.Bandwidth
|
||||
func (a *Agent) isOverLimitClass(u *metrics.Usage, alcs []*store.LimitClass) (store.BandwidthClass, int64, int64, error) {
|
||||
periodBw := make(map[int]struct {
|
||||
rx int64
|
||||
tx int64
|
||||
})
|
||||
|
||||
var allBwcs []store.BandwidthClass
|
||||
for _, alc := range alcs {
|
||||
allBwcs = append(allBwcs, alc)
|
||||
}
|
||||
if limit.Period > 0 {
|
||||
period = limit.Period
|
||||
}
|
||||
rx, tx, err := a.ifx.totalRxTxForAccount(acctId, period)
|
||||
if err != nil {
|
||||
logrus.Error(err)
|
||||
for _, globBwc := range newConfigBandwidthClasses(a.cfg.Bandwidth) {
|
||||
allBwcs = append(allBwcs, globBwc)
|
||||
}
|
||||
|
||||
enforce, warning = a.checkLimit(limit, rx, tx)
|
||||
return enforce, warning, rx, tx, nil
|
||||
// find period data for each class
|
||||
for _, bwc := range allBwcs {
|
||||
if _, found := periodBw[bwc.GetPeriodMinutes()]; !found {
|
||||
rx, tx, err := a.ifx.totalRxTxForAccount(u.AccountId, time.Minute*time.Duration(bwc.GetPeriodMinutes()))
|
||||
if err != nil {
|
||||
return nil, 0, 0, errors.Wrapf(err, "error getting rx/tx for account '%d'", u.AccountId)
|
||||
}
|
||||
periodBw[bwc.GetPeriodMinutes()] = struct {
|
||||
rx int64
|
||||
tx int64
|
||||
}{
|
||||
rx: rx,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// find the highest, most specific limit class that has been exceeded
|
||||
var selectedLc store.BandwidthClass
|
||||
selectedLcPoints := -1
|
||||
var rxBytes int64
|
||||
var txBytes int64
|
||||
for _, bwc := range allBwcs {
|
||||
points := a.bandwidthClassPoints(bwc)
|
||||
if points >= selectedLcPoints {
|
||||
period := periodBw[bwc.GetPeriodMinutes()]
|
||||
if a.limitExceeded(period.rx, period.tx, bwc) {
|
||||
selectedLc = bwc
|
||||
selectedLcPoints = points
|
||||
rxBytes = period.rx
|
||||
txBytes = period.tx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return selectedLc, rxBytes, txBytes, nil
|
||||
}
|
||||
|
||||
func (a *Agent) checkLimit(cfg *BandwidthPerPeriod, rx, tx int64) (enforce, warning bool) {
|
||||
if cfg.Limit.Rx != Unlimited && rx > cfg.Limit.Rx {
|
||||
return true, false
|
||||
func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int {
|
||||
points := 0
|
||||
if !bwc.IsGlobal() {
|
||||
points++
|
||||
}
|
||||
if cfg.Limit.Tx != Unlimited && tx > cfg.Limit.Tx {
|
||||
return true, false
|
||||
if bwc.GetLimitAction() == store.WarningLimitAction {
|
||||
points++
|
||||
}
|
||||
if cfg.Limit.Total != Unlimited && rx+tx > cfg.Limit.Total {
|
||||
return true, false
|
||||
if bwc.GetLimitAction() == store.LimitLimitAction {
|
||||
points += 2
|
||||
}
|
||||
|
||||
if cfg.Warning.Rx != Unlimited && rx > cfg.Warning.Rx {
|
||||
return false, true
|
||||
if bwc.GetShareMode() != "" {
|
||||
points += 5
|
||||
}
|
||||
if cfg.Warning.Tx != Unlimited && tx > cfg.Warning.Tx {
|
||||
return false, true
|
||||
if bwc.GetBackendMode() != "" {
|
||||
points += 10
|
||||
}
|
||||
if cfg.Warning.Total != Unlimited && rx+tx > cfg.Warning.Total {
|
||||
return false, true
|
||||
}
|
||||
|
||||
return false, false
|
||||
return points
|
||||
}
|
||||
|
||||
func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
|
||||
if bwc.GetTxBytes() != Unlimited && tx >= bwc.GetTxBytes() {
|
||||
return true
|
||||
}
|
||||
if bwc.GetRxBytes() != Unlimited && rx >= bwc.GetRxBytes() {
|
||||
return true
|
||||
}
|
||||
if bwc.GetTxBytes() != Unlimited && bwc.GetRxBytes() != Unlimited && tx+rx >= bwc.GetTxBytes()+bwc.GetRxBytes() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ func newLimitAction(str *store.Store, zCfg *zrokEdgeSdk.Config) *limitAction {
|
||||
return &limitAction{str, zCfg}
|
||||
}
|
||||
|
||||
func (a *limitAction) HandleAccount(acct *store.Account, _, _ int64, _ *BandwidthPerPeriod, trx *sqlx.Tx) error {
|
||||
func (a *limitAction) HandleAccount(acct *store.Account, _, _ int64, _ store.BandwidthClass, trx *sqlx.Tx) error {
|
||||
logrus.Infof("limiting '%v'", acct.Email)
|
||||
|
||||
envs, err := a.str.FindEnvironmentsForAccount(acct.Id, trx)
|
||||
|
@ -30,6 +30,10 @@ func (bc *configBandwidthClass) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (bc *configBandwidthClass) GetLimitClassId() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func (bc *configBandwidthClass) GetShareMode() sdk.ShareMode {
|
||||
return ""
|
||||
}
|
||||
|
@ -6,13 +6,13 @@ import (
|
||||
)
|
||||
|
||||
type AccountAction interface {
|
||||
HandleAccount(a *store.Account, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error
|
||||
HandleAccount(a *store.Account, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error
|
||||
}
|
||||
|
||||
type EnvironmentAction interface {
|
||||
HandleEnvironment(e *store.Environment, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error
|
||||
HandleEnvironment(e *store.Environment, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error
|
||||
}
|
||||
|
||||
type ShareAction interface {
|
||||
HandleShare(s *store.Share, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error
|
||||
HandleShare(s *store.Share, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ func newRelaxAction(str *store.Store, zCfg *zrokEdgeSdk.Config) *relaxAction {
|
||||
return &relaxAction{str, zCfg}
|
||||
}
|
||||
|
||||
func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, _ *BandwidthPerPeriod, trx *sqlx.Tx) error {
|
||||
func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, _ store.BandwidthClass, trx *sqlx.Tx) error {
|
||||
logrus.Infof("relaxing '%v'", acct.Email)
|
||||
|
||||
envs, err := a.str.FindEnvironmentsForAccount(acct.Id, trx)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/openziti/zrok/util"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"time"
|
||||
)
|
||||
|
||||
type warningAction struct {
|
||||
@ -18,27 +19,27 @@ func newWarningAction(cfg *emailUi.Config, str *store.Store) *warningAction {
|
||||
return &warningAction{str, cfg}
|
||||
}
|
||||
|
||||
func (a *warningAction) HandleAccount(acct *store.Account, rxBytes, txBytes int64, limit *BandwidthPerPeriod, _ *sqlx.Tx) error {
|
||||
func (a *warningAction) HandleAccount(acct *store.Account, rxBytes, txBytes int64, limit store.BandwidthClass, _ *sqlx.Tx) error {
|
||||
logrus.Infof("warning '%v'", acct.Email)
|
||||
|
||||
if a.cfg != nil {
|
||||
rxLimit := "(unlimited bytes)"
|
||||
if limit.Limit.Rx != Unlimited {
|
||||
rxLimit = util.BytesToSize(limit.Limit.Rx)
|
||||
if limit.GetRxBytes() != Unlimited {
|
||||
rxLimit = util.BytesToSize(limit.GetRxBytes())
|
||||
}
|
||||
txLimit := "(unlimited bytes)"
|
||||
if limit.Limit.Tx != Unlimited {
|
||||
txLimit = util.BytesToSize(limit.Limit.Tx)
|
||||
if limit.GetTxBytes() != Unlimited {
|
||||
txLimit = util.BytesToSize(limit.GetTxBytes())
|
||||
}
|
||||
totalLimit := "(unlimited bytes)"
|
||||
if limit.Limit.Total != Unlimited {
|
||||
totalLimit = util.BytesToSize(limit.Limit.Total)
|
||||
if limit.GetTotalBytes() != Unlimited {
|
||||
totalLimit = util.BytesToSize(limit.GetTotalBytes())
|
||||
}
|
||||
|
||||
detail := newDetailMessage()
|
||||
detail = detail.append("Your account has received %v and sent %v (for a total of %v), which has triggered a transfer limit warning.", util.BytesToSize(rxBytes), util.BytesToSize(txBytes), util.BytesToSize(rxBytes+txBytes))
|
||||
detail = detail.append("This zrok instance only allows an account to receive %v, send %v, totalling not more than %v for each %v.", rxLimit, txLimit, totalLimit, limit.Period)
|
||||
detail = detail.append("If you exceed the transfer limit, access to your shares will be temporarily disabled (until the last %v falls below the transfer limit)", limit.Period)
|
||||
detail = detail.append("This zrok instance only allows an account to receive %v, send %v, totalling not more than %v for each %v.", rxLimit, txLimit, totalLimit, time.Duration(limit.GetPeriodMinutes())*time.Minute)
|
||||
detail = detail.append("If you exceed the transfer limit, access to your shares will be temporarily disabled (until the last %v falls below the transfer limit)", time.Duration(limit.GetPeriodMinutes())*time.Minute)
|
||||
|
||||
if err := sendLimitWarningEmail(a.cfg, acct.Email, detail); err != nil {
|
||||
return errors.Wrapf(err, "error sending limit warning email to '%v'", acct.Email)
|
||||
|
@ -23,7 +23,7 @@ func (str *Store) ApplyLimitClass(lc *AppliedLimitClass, trx *sqlx.Tx) (int, err
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (str *Store) FindLimitClassesForAccount(acctId int, trx *sqlx.Tx) ([]*LimitClass, error) {
|
||||
func (str *Store) FindAppliedLimitClassesForAccount(acctId int, trx *sqlx.Tx) ([]*LimitClass, error) {
|
||||
rows, err := trx.Queryx("select limit_classes.* from applied_limit_classes, limit_classes where applied_limit_classes.account_id = $1 and applied_limit_classes.limit_class_id = limit_classes.id", acctId)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error finding limit classes for account")
|
||||
|
@ -42,6 +42,54 @@ func (str *Store) FindLatestBandwidthLimitJournal(acctId int, trx *sqlx.Tx) (*Ba
|
||||
return j, nil
|
||||
}
|
||||
|
||||
func (str *Store) IsBandwidthLimitJournalEmptyForGlobal(acctId int, trx *sqlx.Tx) (bool, error) {
|
||||
count := 0
|
||||
if err := trx.QueryRowx("select count(0) from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId).Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
func (str *Store) FindLatestBandwidthLimitJournalForGlobal(acctId int, trx *sqlx.Tx) (*BandwidthLimitJournalEntry, error) {
|
||||
j := &BandwidthLimitJournalEntry{}
|
||||
if err := trx.QueryRowx("select * from bandwidth_limit_journal where account_id = $1 and limit_class_id is null order by id desc limit 1", acctId).Scan(&j); err != nil {
|
||||
return nil, errors.Wrap(err, "error finding bandwidth_limit_journal by account_id for global")
|
||||
}
|
||||
return j, nil
|
||||
}
|
||||
|
||||
func (str *Store) IsBandwidthLimitJournalEmptyForLimitClass(acctId, lcId int, trx *sqlx.Tx) (bool, error) {
|
||||
count := 0
|
||||
if err := trx.QueryRowx("select count(0) from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2", acctId, lcId).Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
func (str *Store) FindLatestBandwidthLimitJournalForLimitClass(acctId, lcId int, trx *sqlx.Tx) (*BandwidthLimitJournalEntry, error) {
|
||||
j := &BandwidthLimitJournalEntry{}
|
||||
if err := trx.QueryRowx("select * from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2 order by id desc limit 1", acctId, lcId).StructScan(j); err != nil {
|
||||
return nil, errors.Wrap(err, "error finding bandwidth_limit_journal by account_id and limit_class_id")
|
||||
}
|
||||
return j, nil
|
||||
}
|
||||
|
||||
func (str *Store) FindAllBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) {
|
||||
rows, err := trx.Queryx("select * from bandwidth_limit_journal")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error finding all from bandwidth_limit_journal")
|
||||
}
|
||||
var jes []*BandwidthLimitJournalEntry
|
||||
for rows.Next() {
|
||||
je := &BandwidthLimitJournalEntry{}
|
||||
if err := rows.StructScan(je); err != nil {
|
||||
return nil, errors.Wrap(err, "error scanning bandwidth_limit_journal")
|
||||
}
|
||||
jes = append(jes, je)
|
||||
}
|
||||
return jes, nil
|
||||
}
|
||||
|
||||
func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) {
|
||||
rows, err := trx.Queryx("select id, account_id, limit_class_id, action, rx_bytes, tx_bytes, created_at, updated_at from bandwidth_limit_journal where id in (select max(id) as id from bandwidth_limit_journal group by account_id)")
|
||||
if err != nil {
|
||||
@ -64,3 +112,10 @@ func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (str *Store) DeleteBandwidthLimitJournalEntryForLimitClass(acctId int, lcId *int, trx *sqlx.Tx) error {
|
||||
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2", acctId, lcId); err != nil {
|
||||
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id = %d", acctId, lcId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
type BandwidthClass interface {
|
||||
IsGlobal() bool
|
||||
GetLimitClassId() int
|
||||
GetShareMode() sdk.ShareMode
|
||||
GetBackendMode() sdk.BackendMode
|
||||
GetPeriodMinutes() int
|
||||
@ -33,35 +34,39 @@ type LimitClass struct {
|
||||
LimitAction LimitAction
|
||||
}
|
||||
|
||||
func (lc *LimitClass) IsGlobal() bool {
|
||||
func (lc LimitClass) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetShareMode() sdk.ShareMode {
|
||||
func (lc LimitClass) GetLimitClassId() int {
|
||||
return lc.Id
|
||||
}
|
||||
|
||||
func (lc LimitClass) GetShareMode() sdk.ShareMode {
|
||||
return lc.ShareMode
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetBackendMode() sdk.BackendMode {
|
||||
func (lc LimitClass) GetBackendMode() sdk.BackendMode {
|
||||
return lc.BackendMode
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetPeriodMinutes() int {
|
||||
func (lc LimitClass) GetPeriodMinutes() int {
|
||||
return lc.PeriodMinutes
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetRxBytes() int64 {
|
||||
func (lc LimitClass) GetRxBytes() int64 {
|
||||
return lc.RxBytes
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetTxBytes() int64 {
|
||||
func (lc LimitClass) GetTxBytes() int64 {
|
||||
return lc.TxBytes
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetTotalBytes() int64 {
|
||||
func (lc LimitClass) GetTotalBytes() int64 {
|
||||
return lc.TotalBytes
|
||||
}
|
||||
|
||||
func (lc *LimitClass) GetLimitAction() LimitAction {
|
||||
func (lc LimitClass) GetLimitAction() LimitAction {
|
||||
return lc.LimitAction
|
||||
}
|
||||
|
||||
@ -74,6 +79,8 @@ func (lc LimitClass) String() string {
|
||||
return string(out)
|
||||
}
|
||||
|
||||
var _ BandwidthClass = (*LimitClass)(nil)
|
||||
|
||||
func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) {
|
||||
stmt, err := trx.Prepare("insert into limit_classes (share_mode, backend_mode, environments, shares, reserved_shares, unique_names, period_minutes, rx_bytes, tx_bytes, total_bytes, limit_action) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) returning id")
|
||||
if err != nil {
|
||||
@ -85,3 +92,11 @@ func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) {
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (str *Store) GetLimitClass(lcId int, trx *sqlx.Tx) (*LimitClass, error) {
|
||||
lc := &LimitClass{}
|
||||
if err := trx.QueryRowx("select * from limit_classes where id = $1", lcId).StructScan(lc); err != nil {
|
||||
return nil, errors.Wrap(err, "error selecting limit_class by id")
|
||||
}
|
||||
return lc, nil
|
||||
}
|
||||
|
@ -65,6 +65,14 @@ func (str *Store) FindShareWithToken(shrToken string, tx *sqlx.Tx) (*Share, erro
|
||||
return shr, nil
|
||||
}
|
||||
|
||||
func (str *Store) FindShareWithTokenEvenIfDeleted(shrToken string, tx *sqlx.Tx) (*Share, error) {
|
||||
shr := &Share{}
|
||||
if err := tx.QueryRowx("select * from shares where token = $1", shrToken).StructScan(shr); err != nil {
|
||||
return nil, errors.Wrap(err, "error selecting share by token, even if deleted")
|
||||
}
|
||||
return shr, nil
|
||||
}
|
||||
|
||||
func (str *Store) ShareWithTokenExists(shrToken string, tx *sqlx.Tx) (bool, error) {
|
||||
count := 0
|
||||
if err := tx.QueryRowx("select count(0) from shares where token = $1 and not deleted", shrToken).Scan(&count); err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user