diff --git a/controller/limits/agent.go b/controller/limits/agent.go index 16d24302..3a3be465 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -313,17 +313,17 @@ func (a *Agent) enforce(u *metrics.Usage) error { return nil } - //shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx) - //if err != nil { - // return err - //} + shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx) + if err != nil { + return err + } ul, err := a.getUserLimits(int(u.AccountId), trx) if err != nil { return err } - exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, ul.bandwidth) + exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode))) if err != nil { return errors.Wrap(err, "error checking limit classes") } diff --git a/controller/limits/userLimits.go b/controller/limits/userLimits.go index 97513c20..bc654160 100644 --- a/controller/limits/userLimits.go +++ b/controller/limits/userLimits.go @@ -14,6 +14,18 @@ type userLimits struct { scopes map[sdk.BackendMode]store.BandwidthClass } +func (ul *userLimits) toBandwidthArray(backendMode sdk.BackendMode) []store.BandwidthClass { + if scopedBwc, found := ul.scopes[backendMode]; found { + out := make([]store.BandwidthClass, 0) + for _, bwc := range ul.bandwidth { + out = append(out, bwc) + } + out = append(out, scopedBwc) + return out + } + return ul.bandwidth +} + func (a *Agent) getUserLimits(acctId int, trx *sqlx.Tx) (*userLimits, error) { resource := newConfigResourceCountClass(a.cfg) cfgBwcs := newConfigBandwidthClasses(a.cfg.Bandwidth)