mirror of
https://github.com/openziti/zrok.git
synced 2024-11-22 08:03:49 +01:00
fix for bandwidth limit check in CanCreateShare; improved code cleanliness (#606)
This commit is contained in:
parent
26acd4f5a6
commit
aee973379c
@ -87,25 +87,14 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, _ s
|
||||
return false, err
|
||||
}
|
||||
|
||||
if ul.resource.IsGlobal() {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
|
||||
lj, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if lj.Action == store.LimitLimitAction {
|
||||
return false, nil
|
||||
}
|
||||
bwcs := ul.toBandwidthArray(backendMode)
|
||||
for _, bwc := range bwcs {
|
||||
latestJe, err := a.isBandwidthClassLimitedForAccount(acctId, bwc, trx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, ul.resource.GetLimitClassId(), trx); err == nil && !empty {
|
||||
lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, ul.resource.GetLimitClassId(), trx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if lj.Action == store.LimitLimitAction {
|
||||
return false, nil
|
||||
}
|
||||
if latestJe != nil {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -261,50 +250,25 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
exceededLc, rxBytes, txBytes, err := a.hasExceededBandwidthLimit(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
|
||||
exceededBwc, rxBytes, txBytes, err := a.anyBandwidthLimitExceeded(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error checking limit classes")
|
||||
}
|
||||
|
||||
if exceededLc != nil {
|
||||
enforced := false
|
||||
var enforcedAt time.Time
|
||||
|
||||
if exceededLc.IsGlobal() {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty {
|
||||
if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil {
|
||||
enforced = latest.Action == exceededLc.GetLimitAction()
|
||||
enforcedAt = latest.UpdatedAt
|
||||
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
|
||||
} else {
|
||||
logrus.Errorf("error getting latest global bandwidth journal entry: %v", err)
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
|
||||
}
|
||||
} else {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty {
|
||||
if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil {
|
||||
enforced = latest.Action == exceededLc.GetLimitAction()
|
||||
enforcedAt = latest.UpdatedAt
|
||||
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
|
||||
} else {
|
||||
logrus.Errorf("error getting latest bandwidth limit journal entry for limit class '%d': %v", exceededLc.GetLimitClassId(), err)
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
|
||||
}
|
||||
if exceededBwc != nil {
|
||||
latestJe, err := a.isBandwidthClassLimitedForAccount(int(u.AccountId), exceededBwc, trx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !enforced {
|
||||
if latestJe == nil {
|
||||
je := &store.BandwidthLimitJournalEntry{
|
||||
AccountId: int(u.AccountId),
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
Action: exceededLc.GetLimitAction(),
|
||||
Action: exceededBwc.GetLimitAction(),
|
||||
}
|
||||
if !exceededLc.IsGlobal() {
|
||||
lcId := exceededLc.GetLimitClassId()
|
||||
if !exceededBwc.IsGlobal() {
|
||||
lcId := exceededBwc.GetLimitClassId()
|
||||
je.LimitClassId = &lcId
|
||||
}
|
||||
if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil {
|
||||
@ -314,17 +278,17 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch exceededLc.GetLimitAction() {
|
||||
switch exceededBwc.GetLimitAction() {
|
||||
case store.LimitLimitAction:
|
||||
for _, limitAction := range a.limitActions {
|
||||
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil {
|
||||
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
|
||||
}
|
||||
}
|
||||
|
||||
case store.WarningLimitAction:
|
||||
for _, warningAction := range a.warningActions {
|
||||
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil {
|
||||
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
|
||||
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
|
||||
}
|
||||
}
|
||||
@ -333,7 +297,7 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt)
|
||||
logrus.Debugf("limit '%v' already applied for '%v' (at: %v)", exceededBwc, acct.Email, latestJe.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
@ -406,7 +370,7 @@ func (a *Agent) relax() error {
|
||||
}
|
||||
|
||||
used := periodBw[bwc.GetPeriodMinutes()]
|
||||
if !a.limitExceeded(used.rx, used.tx, bwc) {
|
||||
if !a.transferBytesExceeded(used.rx, used.tx, bwc) {
|
||||
if bwc.GetLimitAction() == store.LimitLimitAction {
|
||||
logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email)
|
||||
for _, action := range a.relaxActions {
|
||||
@ -447,7 +411,38 @@ func (a *Agent) relax() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
|
||||
func (a *Agent) isBandwidthClassLimitedForAccount(acctId int, bwc store.BandwidthClass, trx *sqlx.Tx) (*store.BandwidthLimitJournalEntry, error) {
|
||||
if bwc.IsGlobal() {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
|
||||
je, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if je.Action == store.LimitLimitAction {
|
||||
logrus.Infof("account '#%d' over bandwidth for global bandwidth class '%v'", acctId, bwc)
|
||||
return je, nil
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, bwc.GetLimitClassId(), trx); err == nil && !empty {
|
||||
je, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, bwc.GetLimitClassId(), trx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if je.Action == store.LimitLimitAction {
|
||||
logrus.Infof("account '#%d' over bandwidth for limit class '%v'", acctId, bwc)
|
||||
return je, nil
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Agent) anyBandwidthLimitExceeded(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
|
||||
periodBw := make(map[int]struct {
|
||||
rx int64
|
||||
tx int64
|
||||
@ -473,7 +468,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
|
||||
}
|
||||
period := periodBw[bwc.GetPeriodMinutes()]
|
||||
|
||||
if a.limitExceeded(period.rx, period.tx, bwc) {
|
||||
if a.transferBytesExceeded(period.rx, period.tx, bwc) {
|
||||
selectedLc = bwc
|
||||
rxBytes = period.rx
|
||||
txBytes = period.tx
|
||||
@ -489,24 +484,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
|
||||
return selectedLc, rxBytes, txBytes, nil
|
||||
}
|
||||
|
||||
func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int {
|
||||
points := 0
|
||||
if !bwc.IsGlobal() {
|
||||
points++
|
||||
}
|
||||
if bwc.GetLimitAction() == store.WarningLimitAction {
|
||||
points++
|
||||
}
|
||||
if bwc.GetLimitAction() == store.LimitLimitAction {
|
||||
points += 2
|
||||
}
|
||||
if bwc.GetBackendMode() != "" {
|
||||
points += 10
|
||||
}
|
||||
return points
|
||||
}
|
||||
|
||||
func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
|
||||
func (a *Agent) transferBytesExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
|
||||
if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() {
|
||||
return true
|
||||
}
|
||||
|
@ -106,13 +106,6 @@ func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*Bandwidth
|
||||
return jes, nil
|
||||
}
|
||||
|
||||
func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error {
|
||||
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1", acctId); err != nil {
|
||||
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d", acctId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (str *Store) DeleteBandwidthLimitJournalEntryForGlobal(acctId int, trx *sqlx.Tx) error {
|
||||
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId); err != nil {
|
||||
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id is null", acctId)
|
||||
|
Loading…
Reference in New Issue
Block a user