diff --git a/controller/limits/agent.go b/controller/limits/agent.go index 29dc828d..16d24302 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -313,17 +313,17 @@ func (a *Agent) enforce(u *metrics.Usage) error { return nil } - shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx) - if err != nil { - return err - } - logrus.Debugf("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode) + //shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx) + //if err != nil { + // return err + //} - alcs, err := a.str.FindAppliedLimitClassesForAccount(int(u.AccountId), trx) + ul, err := a.getUserLimits(int(u.AccountId), trx) if err != nil { return err } - exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, alcs) + + exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, ul.bandwidth) if err != nil { return errors.Wrap(err, "error checking limit classes") } @@ -504,22 +504,17 @@ func (a *Agent) relax() error { return nil } -func (a *Agent) isOverLimitClass(u *metrics.Usage, alcs []*store.LimitClass) (store.BandwidthClass, int64, int64, error) { +func (a *Agent) isOverLimitClass(u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) { periodBw := make(map[int]struct { rx int64 tx int64 }) - var allBwcs []store.BandwidthClass - for _, alc := range alcs { - allBwcs = append(allBwcs, alc) - } - for _, globBwc := range newConfigBandwidthClasses(a.cfg.Bandwidth) { - allBwcs = append(allBwcs, globBwc) - } + var selectedLc store.BandwidthClass + var rxBytes int64 + var txBytes int64 - // find period data for each class - for _, bwc := range allBwcs { + for _, bwc := range bwcs { if _, found := periodBw[bwc.GetPeriodMinutes()]; !found { rx, tx, err := a.ifx.totalRxTxForAccount(u.AccountId, time.Minute*time.Duration(bwc.GetPeriodMinutes())) if err != nil { @@ -533,27 +528,19 @@ func (a *Agent) isOverLimitClass(u *metrics.Usage, alcs []*store.LimitClass) (st tx: tx, } } + period := periodBw[bwc.GetPeriodMinutes()] + + if a.limitExceeded(period.rx, period.tx, bwc) { + selectedLc = bwc + rxBytes = period.rx + txBytes = period.tx + } else { + logrus.Debugf("limit ok '%v' with rx: %d, tx: %d, total: %d", bwc, period.rx, period.tx, period.rx+period.tx) + } } - // find the highest, most specific limit class that has been exceeded - var selectedLc store.BandwidthClass - selectedLcPoints := -1 - var rxBytes int64 - var txBytes int64 - for _, bwc := range allBwcs { - points := a.bandwidthClassPoints(bwc) - if points >= selectedLcPoints { - period := periodBw[bwc.GetPeriodMinutes()] - if a.limitExceeded(period.rx, period.tx, bwc) { - selectedLc = bwc - selectedLcPoints = points - rxBytes = period.rx - txBytes = period.tx - logrus.Debugf("exceeded limit '%v' with rx: %d, tx: %d", bwc.String(), period.rx, period.tx) - } else { - logrus.Debugf("limit '%v' ok with rx: %d, tx: %d", bwc.String(), period.rx, period.tx) - } - } + if selectedLc != nil { + logrus.Infof("exceeded limit '%v' with rx: %d, tx: %d, total: %d", selectedLc, rxBytes, txBytes, rxBytes+txBytes) } return selectedLc, rxBytes, txBytes, nil