mirror of
https://github.com/openziti/zrok.git
synced 2025-06-24 03:31:57 +02:00
fix for bandwidth limit check in CanCreateShare; improved code cleanliness (#606)
This commit is contained in:
parent
26acd4f5a6
commit
aee973379c
@ -87,25 +87,14 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, _ s
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ul.resource.IsGlobal() {
|
bwcs := ul.toBandwidthArray(backendMode)
|
||||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
|
for _, bwc := range bwcs {
|
||||||
lj, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
|
latestJe, err := a.isBandwidthClassLimitedForAccount(acctId, bwc, trx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
|
||||||
if lj.Action == store.LimitLimitAction {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
if latestJe != nil {
|
||||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, ul.resource.GetLimitClassId(), trx); err == nil && !empty {
|
return false, nil
|
||||||
lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, ul.resource.GetLimitClassId(), trx)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if lj.Action == store.LimitLimitAction {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -261,50 +250,25 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exceededLc, rxBytes, txBytes, err := a.hasExceededBandwidthLimit(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
|
exceededBwc, rxBytes, txBytes, err := a.anyBandwidthLimitExceeded(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error checking limit classes")
|
return errors.Wrap(err, "error checking limit classes")
|
||||||
}
|
}
|
||||||
|
|
||||||
if exceededLc != nil {
|
if exceededBwc != nil {
|
||||||
enforced := false
|
latestJe, err := a.isBandwidthClassLimitedForAccount(int(u.AccountId), exceededBwc, trx)
|
||||||
var enforcedAt time.Time
|
if err != nil {
|
||||||
|
return err
|
||||||
if exceededLc.IsGlobal() {
|
|
||||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty {
|
|
||||||
if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil {
|
|
||||||
enforced = latest.Action == exceededLc.GetLimitAction()
|
|
||||||
enforcedAt = latest.UpdatedAt
|
|
||||||
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
|
|
||||||
} else {
|
|
||||||
logrus.Errorf("error getting latest global bandwidth journal entry: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty {
|
|
||||||
if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil {
|
|
||||||
enforced = latest.Action == exceededLc.GetLimitAction()
|
|
||||||
enforcedAt = latest.UpdatedAt
|
|
||||||
logrus.Debugf("limit '%v' already applied (enforced: %t)", exceededLc, enforced)
|
|
||||||
} else {
|
|
||||||
logrus.Errorf("error getting latest bandwidth limit journal entry for limit class '%d': %v", exceededLc.GetLimitClassId(), err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logrus.Debugf("no bandwidth limit journal entry for '%v'", exceededLc)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if latestJe == nil {
|
||||||
if !enforced {
|
|
||||||
je := &store.BandwidthLimitJournalEntry{
|
je := &store.BandwidthLimitJournalEntry{
|
||||||
AccountId: int(u.AccountId),
|
AccountId: int(u.AccountId),
|
||||||
RxBytes: rxBytes,
|
RxBytes: rxBytes,
|
||||||
TxBytes: txBytes,
|
TxBytes: txBytes,
|
||||||
Action: exceededLc.GetLimitAction(),
|
Action: exceededBwc.GetLimitAction(),
|
||||||
}
|
}
|
||||||
if !exceededLc.IsGlobal() {
|
if !exceededBwc.IsGlobal() {
|
||||||
lcId := exceededLc.GetLimitClassId()
|
lcId := exceededBwc.GetLimitClassId()
|
||||||
je.LimitClassId = &lcId
|
je.LimitClassId = &lcId
|
||||||
}
|
}
|
||||||
if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil {
|
if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil {
|
||||||
@ -314,17 +278,17 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
switch exceededLc.GetLimitAction() {
|
switch exceededBwc.GetLimitAction() {
|
||||||
case store.LimitLimitAction:
|
case store.LimitLimitAction:
|
||||||
for _, limitAction := range a.limitActions {
|
for _, limitAction := range a.limitActions {
|
||||||
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil {
|
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
|
||||||
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
|
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case store.WarningLimitAction:
|
case store.WarningLimitAction:
|
||||||
for _, warningAction := range a.warningActions {
|
for _, warningAction := range a.warningActions {
|
||||||
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, ul, trx); err != nil {
|
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil {
|
||||||
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
|
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -333,7 +297,7 @@ func (a *Agent) enforce(u *metrics.Usage) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt)
|
logrus.Debugf("limit '%v' already applied for '%v' (at: %v)", exceededBwc, acct.Email, latestJe.CreatedAt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -406,7 +370,7 @@ func (a *Agent) relax() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
used := periodBw[bwc.GetPeriodMinutes()]
|
used := periodBw[bwc.GetPeriodMinutes()]
|
||||||
if !a.limitExceeded(used.rx, used.tx, bwc) {
|
if !a.transferBytesExceeded(used.rx, used.tx, bwc) {
|
||||||
if bwc.GetLimitAction() == store.LimitLimitAction {
|
if bwc.GetLimitAction() == store.LimitLimitAction {
|
||||||
logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email)
|
logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email)
|
||||||
for _, action := range a.relaxActions {
|
for _, action := range a.relaxActions {
|
||||||
@ -447,7 +411,38 @@ func (a *Agent) relax() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
|
func (a *Agent) isBandwidthClassLimitedForAccount(acctId int, bwc store.BandwidthClass, trx *sqlx.Tx) (*store.BandwidthLimitJournalEntry, error) {
|
||||||
|
if bwc.IsGlobal() {
|
||||||
|
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty {
|
||||||
|
je, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if je.Action == store.LimitLimitAction {
|
||||||
|
logrus.Infof("account '#%d' over bandwidth for global bandwidth class '%v'", acctId, bwc)
|
||||||
|
return je, nil
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, bwc.GetLimitClassId(), trx); err == nil && !empty {
|
||||||
|
je, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, bwc.GetLimitClassId(), trx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if je.Action == store.LimitLimitAction {
|
||||||
|
logrus.Infof("account '#%d' over bandwidth for limit class '%v'", acctId, bwc)
|
||||||
|
return je, nil
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) anyBandwidthLimitExceeded(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) {
|
||||||
periodBw := make(map[int]struct {
|
periodBw := make(map[int]struct {
|
||||||
rx int64
|
rx int64
|
||||||
tx int64
|
tx int64
|
||||||
@ -473,7 +468,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
|
|||||||
}
|
}
|
||||||
period := periodBw[bwc.GetPeriodMinutes()]
|
period := periodBw[bwc.GetPeriodMinutes()]
|
||||||
|
|
||||||
if a.limitExceeded(period.rx, period.tx, bwc) {
|
if a.transferBytesExceeded(period.rx, period.tx, bwc) {
|
||||||
selectedLc = bwc
|
selectedLc = bwc
|
||||||
rxBytes = period.rx
|
rxBytes = period.rx
|
||||||
txBytes = period.tx
|
txBytes = period.tx
|
||||||
@ -489,24 +484,7 @@ func (a *Agent) hasExceededBandwidthLimit(u *metrics.Usage, bwcs []store.Bandwid
|
|||||||
return selectedLc, rxBytes, txBytes, nil
|
return selectedLc, rxBytes, txBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int {
|
func (a *Agent) transferBytesExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
|
||||||
points := 0
|
|
||||||
if !bwc.IsGlobal() {
|
|
||||||
points++
|
|
||||||
}
|
|
||||||
if bwc.GetLimitAction() == store.WarningLimitAction {
|
|
||||||
points++
|
|
||||||
}
|
|
||||||
if bwc.GetLimitAction() == store.LimitLimitAction {
|
|
||||||
points += 2
|
|
||||||
}
|
|
||||||
if bwc.GetBackendMode() != "" {
|
|
||||||
points += 10
|
|
||||||
}
|
|
||||||
return points
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
|
|
||||||
if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() {
|
if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -106,13 +106,6 @@ func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*Bandwidth
|
|||||||
return jes, nil
|
return jes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error {
|
|
||||||
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1", acctId); err != nil {
|
|
||||||
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d", acctId)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (str *Store) DeleteBandwidthLimitJournalEntryForGlobal(acctId int, trx *sqlx.Tx) error {
|
func (str *Store) DeleteBandwidthLimitJournalEntryForGlobal(acctId int, trx *sqlx.Tx) error {
|
||||||
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId); err != nil {
|
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId); err != nil {
|
||||||
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id is null", acctId)
|
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id is null", acctId)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user