fix for bandwidth limit check in CanCreateShare; improved code cleanliness (#606)

This commit is contained in:
Michael Quigley 2024-06-07 11:38:49 -04:00
parent 26acd4f5a6
commit aee973379c
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
2 changed files with 55 additions and 84 deletions

View File

@ -87,25 +87,14 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, _ s
return false, err return false, err
} }
if ul.resource.IsGlobal() { bwcs := ul.toBandwidthArray(backendMode)
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty { for _, bwc := range bwcs {
lj, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx) latestJe, err := a.isBandwidthClassLimitedForAccount(acctId, bwc, trx)
if err != nil { if err != nil {
return false, err return false, err
}
if lj.Action == store.LimitLimitAction {
return false, nil
}
} }
} else { if latestJe != nil {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, ul.resource.GetLimitClassId(), trx); err == nil && !empty { return false, nil
lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, ul.resource.GetLimitClassId(), trx)
if err != nil {
return false, err
}
if lj.Action == store.LimitLimitAction {
return false, nil
}
} }
} }
@ -261,50 +250,25 @@ func (a *Agent) enforce(u *metrics.Usage) error {
return err return err
} }
exceededLc, rxBytes, txBytes, err := a.hasExceededBandwidthLimit(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode))) exceededBwc, rxBytes, txBytes, err := a.anyBandwidthLimitExceeded(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
if err != nil { if err != nil {
return errors.Wrap(err, "error checking limit classes") return errors.Wrap(err, "error checking limit classes")
} }
if exceededLc != nil { if exceededBwc != nil {
enforced := false latestJe, err := a.isBandwidthClassLimitedForAccount(int(u.AccountId), exceededBwc, trx)
var enforcedAt time.Time if err != nil {
return err
if exceededLc.IsGlobal() {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil {
enforced = latest.Action == exceededLc.GetLimitAction()
enforcedAt = latest.UpdatedAt
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
} else {
logrus.Errorf("error getting latest global bandwidth journal entry: %v", err)
}
} else {
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
}
} else {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil {
enforced = latest.Action == exceededLc.GetLimitAction()
enforcedAt = latest.UpdatedAt
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
} else {
logrus.Errorf("error getting latest bandwidth limit journal entry for limit class '%d': %v", exceededLc.GetLimitClassId(), err)
}
} else {
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
}
} }
if latestJe == nil {
if !enforced {
je := &store.BandwidthLimitJournalEntry{ je := &store.BandwidthLimitJournalEntry{
AccountId: int(u.AccountId), AccountId: int(u.AccountId),
RxBytes: rxBytes, RxBytes: rxBytes,
TxBytes: txBytes, TxBytes: txBytes,
Action: exceededLc.GetLimitAction(), Action: exceededBwc.GetLimitAction(),
} }
if !exceededLc.IsGlobal() { if !exceededBwc.IsGlobal() {
lcId := exceededLc.GetLimitClassId() lcId := exceededBwc.GetLimitClassId()
je.LimitClassId = &lcId je.LimitClassId = &lcId
} }
if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil { if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil {
@ -314,17 +278,17 @@ func (a *Agent) enforce(u *metrics.Usage) error {
if err != nil { if err != nil {
return err return err
} }
switch exceededLc.GetLimitAction() { switch exceededBwc.GetLimitAction() {
case store.LimitLimitAction: case store.LimitLimitAction:
for _, limitAction := range a.limitActions { for _, limitAction := range a.limitActions {
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil { if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String()) return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
} }
} }
case store.WarningLimitAction: case store.WarningLimitAction:
for _, warningAction := range a.warningActions { for _, warningAction := range a.warningActions {
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil { if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String()) return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
} }
} }
@ -333,7 +297,7 @@ func (a *Agent) enforce(u *metrics.Usage) error {
return err return err
} }
} else { } else {
logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt) logrus.Debugf("limit '%v' already applied for '%v' (at: %v)", exceededBwc, acct.Email, latestJe.CreatedAt)
} }
} }
@ -406,7 +370,7 @@ func (a *Agent) relax() error {
} }
used := periodBw[bwc.GetPeriodMinutes()] used := periodBw[bwc.GetPeriodMinutes()]
if !a.limitExceeded(used.rx, used.tx, bwc) { if !a.transferBytesExceeded(used.rx, used.tx, bwc) {
if bwc.GetLimitAction() == store.LimitLimitAction { if bwc.GetLimitAction() == store.LimitLimitAction {
logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email) logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email)
for _, action := range a.relaxActions { for _, action := range a.relaxActions {
@ -447,7 +411,38 @@ func (a *Agent) relax() error {
return nil return nil
} }
func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) { func (a *Agent) isBandwidthClassLimitedForAccount(acctId int, bwc store.BandwidthClass, trx *sqlx.Tx) (*store.BandwidthLimitJournalEntry, error) {
if bwc.IsGlobal() {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
je, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
if err != nil {
return nil, err
}
if je.Action == store.LimitLimitAction {
logrus.Infof("account '#%d' over bandwidth for global bandwidth class '%v'", acctId, bwc)
return je, nil
}
} else if err != nil {
return nil, err
}
} else {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, bwc.GetLimitClassId(), trx); err == nil && !empty {
je, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, bwc.GetLimitClassId(), trx)
if err != nil {
return nil, err
}
if je.Action == store.LimitLimitAction {
logrus.Infof("account '#%d' over bandwidth for limit class '%v'", acctId, bwc)
return je, nil
}
} else if err != nil {
return nil, err
}
}
return nil, nil
}
func (a *Agent) anyBandwidthLimitExceeded(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
periodBw := make(map[int]struct { periodBw := make(map[int]struct {
rx int64 rx int64
tx int64 tx int64
@ -473,7 +468,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
} }
period := periodBw[bwc.GetPeriodMinutes()] period := periodBw[bwc.GetPeriodMinutes()]
if a.limitExceeded(period.rx, period.tx, bwc) { if a.transferBytesExceeded(period.rx, period.tx, bwc) {
selectedLc = bwc selectedLc = bwc
rxBytes = period.rx rxBytes = period.rx
txBytes = period.tx txBytes = period.tx
@ -489,24 +484,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
return selectedLc, rxBytes, txBytes, nil return selectedLc, rxBytes, txBytes, nil
} }
func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int { func (a *Agent) transferBytesExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
points := 0
if !bwc.IsGlobal() {
points++
}
if bwc.GetLimitAction() == store.WarningLimitAction {
points++
}
if bwc.GetLimitAction() == store.LimitLimitAction {
points += 2
}
if bwc.GetBackendMode() != "" {
points += 10
}
return points
}
func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() { if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() {
return true return true
} }

View File

@ -106,13 +106,6 @@ func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*Bandwidth
return jes, nil return jes, nil
} }
func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error {
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1", acctId); err != nil {
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d", acctId)
}
return nil
}
func (str *Store) DeleteBandwidthLimitJournalEntryForGlobal(acctId int, trx *sqlx.Tx) error { func (str *Store) DeleteBandwidthLimitJournalEntryForGlobal(acctId int, trx *sqlx.Tx) error {
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId); err != nil { if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId); err != nil {
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id is null", acctId) return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id is null", acctId)