fix for bandwidth limit check in CanCreateShare; improved code cleanliness (#606)

This commit is contained in:
Michael Quigley 2024-06-07 11:38:49 -04:00
parent 26acd4f5a6
commit aee973379c
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
2 changed files with 55 additions and 84 deletions

View File

@ -87,25 +87,14 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, _ s
return false, err
}
if ul.resource.IsGlobal() {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
lj, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
if err != nil {
return false, err
}
if lj.Action == store.LimitLimitAction {
return false, nil
}
bwcs := ul.toBandwidthArray(backendMode)
for _, bwc := range bwcs {
latestJe, err := a.isBandwidthClassLimitedForAccount(acctId, bwc, trx)
if err != nil {
return false, err
}
} else {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, ul.resource.GetLimitClassId(), trx); err == nil && !empty {
lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, ul.resource.GetLimitClassId(), trx)
if err != nil {
return false, err
}
if lj.Action == store.LimitLimitAction {
return false, nil
}
if latestJe != nil {
return false, nil
}
}
@ -261,50 +250,25 @@ func (a *Agent) enforce(u *metrics.Usage) error {
return err
}
exceededLc, rxBytes, txBytes, err := a.hasExceededBandwidthLimit(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
exceededBwc, rxBytes, txBytes, err := a.anyBandwidthLimitExceeded(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
if err != nil {
return errors.Wrap(err, "error checking limit classes")
}
if exceededLc != nil {
enforced := false
var enforcedAt time.Time
if exceededLc.IsGlobal() {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil {
enforced = latest.Action == exceededLc.GetLimitAction()
enforcedAt = latest.UpdatedAt
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
} else {
logrus.Errorf("error getting latest global bandwidth journal entry: %v", err)
}
} else {
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
}
} else {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil {
enforced = latest.Action == exceededLc.GetLimitAction()
enforcedAt = latest.UpdatedAt
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
} else {
logrus.Errorf("error getting latest bandwidth limit journal entry for limit class '%d': %v", exceededLc.GetLimitClassId(), err)
}
} else {
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
}
if exceededBwc != nil {
latestJe, err := a.isBandwidthClassLimitedForAccount(int(u.AccountId), exceededBwc, trx)
if err != nil {
return err
}
if !enforced {
if latestJe == nil {
je := &store.BandwidthLimitJournalEntry{
AccountId: int(u.AccountId),
RxBytes: rxBytes,
TxBytes: txBytes,
Action: exceededLc.GetLimitAction(),
Action: exceededBwc.GetLimitAction(),
}
if !exceededLc.IsGlobal() {
lcId := exceededLc.GetLimitClassId()
if !exceededBwc.IsGlobal() {
lcId := exceededBwc.GetLimitClassId()
je.LimitClassId = &lcId
}
if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil {
@ -314,17 +278,17 @@ func (a *Agent) enforce(u *metrics.Usage) error {
if err != nil {
return err
}
switch exceededLc.GetLimitAction() {
switch exceededBwc.GetLimitAction() {
case store.LimitLimitAction:
for _, limitAction := range a.limitActions {
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil {
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
}
}
case store.WarningLimitAction:
for _, warningAction := range a.warningActions {
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil {
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
}
}
@ -333,7 +297,7 @@ func (a *Agent) enforce(u *metrics.Usage) error {
return err
}
} else {
logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt)
logrus.Debugf("limit '%v' already applied for '%v' (at: %v)", exceededBwc, acct.Email, latestJe.CreatedAt)
}
}
@ -406,7 +370,7 @@ func (a *Agent) relax() error {
}
used := periodBw[bwc.GetPeriodMinutes()]
if !a.limitExceeded(used.rx, used.tx, bwc) {
if !a.transferBytesExceeded(used.rx, used.tx, bwc) {
if bwc.GetLimitAction() == store.LimitLimitAction {
logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email)
for _, action := range a.relaxActions {
@ -447,7 +411,38 @@ func (a *Agent) relax() error {
return nil
}
func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
func (a *Agent) isBandwidthClassLimitedForAccount(acctId int, bwc store.BandwidthClass, trx *sqlx.Tx) (*store.BandwidthLimitJournalEntry, error) {
if bwc.IsGlobal() {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
je, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
if err != nil {
return nil, err
}
if je.Action == store.LimitLimitAction {
logrus.Infof("account '#%d' over bandwidth for global bandwidth class '%v'", acctId, bwc)
return je, nil
}
} else if err != nil {
return nil, err
}
} else {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, bwc.GetLimitClassId(), trx); err == nil && !empty {
je, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, bwc.GetLimitClassId(), trx)
if err != nil {
return nil, err
}
if je.Action == store.LimitLimitAction {
logrus.Infof("account '#%d' over bandwidth for limit class '%v'", acctId, bwc)
return je, nil
}
} else if err != nil {
return nil, err
}
}
return nil, nil
}
func (a *Agent) anyBandwidthLimitExceeded(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
periodBw := make(map[int]struct {
rx int64
tx int64
@ -473,7 +468,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
}
period := periodBw[bwc.GetPeriodMinutes()]
if a.limitExceeded(period.rx, period.tx, bwc) {
if a.transferBytesExceeded(period.rx, period.tx, bwc) {
selectedLc = bwc
rxBytes = period.rx
txBytes = period.tx
@ -489,24 +484,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
return selectedLc, rxBytes, txBytes, nil
}
func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int {
points := 0
if !bwc.IsGlobal() {
points++
}
if bwc.GetLimitAction() == store.WarningLimitAction {
points++
}
if bwc.GetLimitAction() == store.LimitLimitAction {
points += 2
}
if bwc.GetBackendMode() != "" {
points += 10
}
return points
}
func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
func (a *Agent) transferBytesExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() {
return true
}

View File

@ -106,13 +106,6 @@ func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*Bandwidth
return jes, nil
}
func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error {
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1", acctId); err != nil {
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d", acctId)
}
return nil
}
func (str *Store) DeleteBandwidthLimitJournalEntryForGlobal(acctId int, trx *sqlx.Tx) error {
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId); err != nil {
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id is null", acctId)