diff --git a/controller/limits/agent.go b/controller/limits/agent.go index 9dd9332e..619f6e28 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -94,7 +94,7 @@ func (a *Agent) CanCreateEnvironment(acctId int, trx *sqlx.Tx) (bool, error) { return true, nil } -func (a *Agent) CanCreateShare(acctId, envId int, trx *sqlx.Tx) (bool, error) { +func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, trx *sqlx.Tx) (bool, error) { if a.cfg.Enforcing { if err := a.str.LimitCheckLock(acctId, trx); err != nil { return false, err @@ -123,21 +123,37 @@ func (a *Agent) CanCreateShare(acctId, envId int, trx *sqlx.Tx) (bool, error) { return false, err } - if a.cfg.Shares > Unlimited { + if a.cfg.Shares > Unlimited || (reserved && a.cfg.ReservedShares > Unlimited) || (reserved && uniqueName && a.cfg.UniqueNames > Unlimited) { envs, err := a.str.FindEnvironmentsForAccount(acctId, trx) if err != nil { return false, err } total := 0 + reserveds := 0 + uniqueNames := 0 for i := range envs { shrs, err := a.str.FindSharesForEnvironment(envs[i].Id, trx) if err != nil { return false, errors.Wrapf(err, "unable to find shares for environment '%v'", envs[i].ZId) } total += len(shrs) + for _, shr := range shrs { + if shr.Reserved { + reserveds++ + } + if shr.UniqueName { + uniqueNames++ + } + } if total+1 > a.cfg.Shares { return false, nil } + if reserved && reserveds+1 > a.cfg.ReservedShares { + return false, nil + } + if reserved && uniqueName && uniqueNames+1 > a.cfg.UniqueNames { + return false, nil + } logrus.Infof("total = %d", total) } } diff --git a/controller/share.go b/controller/share.go index b2daabf7..7140be3e 100644 --- a/controller/share.go +++ b/controller/share.go @@ -49,7 +49,7 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr return share.NewShareInternalServerError() } - if err := h.checkLimits(envId, principal, trx); err != nil { + if err := h.checkLimits(envId, principal, params.Body.Reserved, params.Body.UniqueName != "", trx); err != nil { logrus.Errorf("limits error: %v", err) return share.NewShareUnauthorized() } @@ -190,10 +190,10 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr }) } -func (h *shareHandler) checkLimits(envId int, principal *rest_model_zrok.Principal, trx *sqlx.Tx) error { +func (h *shareHandler) checkLimits(envId int, principal *rest_model_zrok.Principal, reserved, uniqueName bool, trx *sqlx.Tx) error { if !principal.Limitless { if limitsAgent != nil { - ok, err := limitsAgent.CanCreateShare(int(principal.ID), envId, trx) + ok, err := limitsAgent.CanCreateShare(int(principal.ID), envId, reserved, uniqueName, trx) if err != nil { return errors.Wrapf(err, "error checking share limits for '%v'", principal.Email) }