diff --git a/controller/limits/agent.go b/controller/limits/agent.go index b8a817e0..8f9ec89a 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -76,34 +76,18 @@ func (a *Agent) CanCreateEnvironment(acctId int, trx *sqlx.Tx) (bool, error) { return true, nil } -func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, shareMode sdk.ShareMode, backendMode sdk.BackendMode, trx *sqlx.Tx) (bool, error) { +func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, _ sdk.ShareMode, backendMode sdk.BackendMode, trx *sqlx.Tx) (bool, error) { if a.cfg.Enforcing { if err := a.str.LimitCheckLock(acctId, trx); err != nil { return false, err } - alcs, err := a.str.FindAppliedLimitClassesForAccount(acctId, trx) + ul, err := a.getUserLimits(acctId, trx) if err != nil { return false, err } - maxShares := a.cfg.Shares - maxReservedShares := a.cfg.ReservedShares - maxUniqueNames := a.cfg.UniqueNames - var lcId *int - var points = -1 - for _, alc := range alcs { - if a.bandwidthClassPoints(alc) > points { - if alc.Shares >= maxShares || alc.ReservedShares >= maxReservedShares || alc.UniqueNames >= maxUniqueNames { - maxShares = alc.Shares - maxReservedShares = alc.ReservedShares - maxUniqueNames = alc.UniqueNames - lcId = &alc.Id - points = a.bandwidthClassPoints(alc) - } - } - } - if lcId == nil { + if ul.resource.IsGlobal() { if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty { lj, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx) if err != nil { @@ -114,8 +98,8 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, sha } } } else { - if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, *lcId, trx); err == nil && !empty { - lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, *lcId, trx) + if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, ul.resource.GetLimitClassId(), trx); err == nil && !empty { + lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, ul.resource.GetLimitClassId(), trx) if err != nil { return false, err } @@ -125,7 +109,7 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, sha } } - if maxShares > store.Unlimited || (reserved && maxReservedShares > store.Unlimited) || (reserved && uniqueName && maxUniqueNames > store.Unlimited) { + if ul.resource.GetShares() > store.Unlimited || (reserved && ul.resource.GetReservedShares() > store.Unlimited) || (reserved && uniqueName && ul.resource.GetUniqueNames() > store.Unlimited) { envs, err := a.str.FindEnvironmentsForAccount(acctId, trx) if err != nil { return false, err @@ -147,15 +131,15 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, sha uniqueNames++ } } - if total+1 > a.cfg.Shares { + if total+1 > ul.resource.GetShares() { logrus.Debugf("account '%d', environment '%d' over shares limit '%d'", acctId, envId, a.cfg.Shares) return false, nil } - if reserved && reserveds+1 > a.cfg.ReservedShares { + if reserved && reserveds+1 > ul.resource.GetReservedShares() { logrus.Debugf("account '%v', environment '%d' over reserved shares limit '%d'", acctId, envId, a.cfg.ReservedShares) return false, nil } - if reserved && uniqueName && uniqueNames+1 > a.cfg.UniqueNames { + if reserved && uniqueName && uniqueNames+1 > ul.resource.GetUniqueNames() { logrus.Debugf("account '%v', environment '%d' over unique names limit '%d'", acctId, envId, a.cfg.UniqueNames) return false, nil } @@ -177,28 +161,12 @@ func (a *Agent) CanAccessShare(shrId int, trx *sqlx.Tx) (bool, error) { return false, err } if env.AccountId != nil { - alcs, err := a.str.FindAppliedLimitClassesForAccount(*env.AccountId, trx) + ul, err := a.getUserLimits(*env.AccountId, trx) if err != nil { return false, err } - maxShares := a.cfg.Shares - maxReservedShares := a.cfg.ReservedShares - maxUniqueNames := a.cfg.UniqueNames - var lcId *int - var points = -1 - for _, alc := range alcs { - if a.bandwidthClassPoints(alc) > points { - if alc.Shares >= maxShares || alc.ReservedShares >= maxReservedShares || alc.UniqueNames >= maxUniqueNames { - maxShares = alc.Shares - maxReservedShares = alc.ReservedShares - maxUniqueNames = alc.UniqueNames - lcId = &alc.Id - points = a.bandwidthClassPoints(alc) - } - } - } - if lcId == nil { + if ul.resource.IsGlobal() { if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(*env.AccountId, trx); err == nil && !empty { lj, err := a.str.FindLatestBandwidthLimitJournalForGlobal(*env.AccountId, trx) if err != nil { @@ -209,8 +177,8 @@ func (a *Agent) CanAccessShare(shrId int, trx *sqlx.Tx) (bool, error) { } } } else { - if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(*env.AccountId, *lcId, trx); err == nil && !empty { - lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(*env.AccountId, *lcId, trx) + if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(*env.AccountId, ul.resource.GetLimitClassId(), trx); err == nil && !empty { + lj, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(*env.AccountId, ul.resource.GetLimitClassId(), trx) if err != nil { return false, err }