From 0fa1a350cde19386cccde2b7b5966847fabc5131 Mon Sep 17 00:00:00 2001 From: Michael Quigley Date: Tue, 6 Jun 2023 11:29:22 -0400 Subject: [PATCH] limits.CanAccessShare --- controller/access.go | 10 +++---- controller/limits/agent.go | 53 +++++++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/controller/access.go b/controller/access.go index 7e8b3dbf..56f9c8c5 100644 --- a/controller/access.go +++ b/controller/access.go @@ -57,7 +57,7 @@ func (h *accessHandler) Handle(params share.AccessParams, principal *rest_model_ return share.NewAccessNotFound() } - if err := h.checkLimits(shrToken, trx); err != nil { + if err := h.checkLimits(shr, trx); err != nil { logrus.Errorf("cannot access limited share for '%v': %v", principal.Email, err) return share.NewAccessNotFound() } @@ -99,14 +99,14 @@ func (h *accessHandler) Handle(params share.AccessParams, principal *rest_model_ }) } -func (h *accessHandler) checkLimits(shrToken string, trx *sqlx.Tx) error { +func (h *accessHandler) checkLimits(shr *store.Share, trx *sqlx.Tx) error { if limitsAgent != nil { - ok, err := limitsAgent.CanAccessShare(shrToken, trx) + ok, err := limitsAgent.CanAccessShare(shr.Id, trx) if err != nil { - return errors.Wrapf(err, "error checking share limits for '%v'", shrToken) + return errors.Wrapf(err, "error checking share limits for '%v'", shr.Token) } if !ok { - return errors.Errorf("share limit check failed for '%v'", shrToken) + return errors.Errorf("share limit check failed for '%v'", shr.Token) } } return nil diff --git a/controller/limits/agent.go b/controller/limits/agent.go index b67e9839..81c5e63b 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -139,7 +139,58 @@ func (a *Agent) CanCreateShare(acctId, envId int, trx *sqlx.Tx) (bool, error) { return true, nil } -func (a *Agent) CanAccessShare(shrToken string, trx *sqlx.Tx) (bool, error) { +func (a *Agent) CanAccessShare(shrId int, trx *sqlx.Tx) (bool, error) { + if a.cfg.Enforcing { + shr, err := a.str.GetShare(shrId, trx) + if err != nil { + return false, err + } + if empty, err := a.str.IsShareLimitJournalEmpty(shr.Id, trx); err == nil && !empty { + slj, err := a.str.FindLatestShareLimitJournal(shr.Id, trx) + if err != nil { + return false, err + } + if slj.Action == store.LimitAction { + return false, nil + } + } else if err != nil { + return false, err + } + + env, err := a.str.GetEnvironment(shr.EnvironmentId, trx) + if err != nil { + return false, err + } + if empty, err := a.str.IsEnvironmentLimitJournalEmpty(env.Id, trx); err == nil && !empty { + elj, err := a.str.FindLatestEnvironmentLimitJournal(env.Id, trx) + if err != nil { + return false, err + } + if elj.Action == store.LimitAction { + return false, nil + } + } else if err != nil { + return false, err + } + + if env.AccountId != nil { + acct, err := a.str.GetAccount(*env.AccountId, trx) + if err != nil { + return false, err + } + if empty, err := a.str.IsAccountLimitJournalEmpty(acct.Id, trx); err == nil && !empty { + alj, err := a.str.FindLatestAccountLimitJournal(acct.Id, trx) + if err != nil { + return false, err + } + if alj.Action == store.LimitAction { + return false, nil + } + } else if err != nil { + return false, err + } + } + } return true, nil }