From 60893100a573dbb0b311492342d796251e634da6 Mon Sep 17 00:00:00 2001 From: Michael Quigley Date: Fri, 7 Jun 2024 14:02:28 -0400 Subject: [PATCH] fix for not relaxing scoped bandwidth clases that are in a limited state (#606) --- controller/limits/relaxAction.go | 21 +++++++++++++++++++-- controller/store/bandwidthLimitJournal.go | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/controller/limits/relaxAction.go b/controller/limits/relaxAction.go index 454b03e0..918090f9 100644 --- a/controller/limits/relaxAction.go +++ b/controller/limits/relaxAction.go @@ -27,6 +27,23 @@ func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, bwc store.B return errors.Wrapf(err, "error finding environments for account '%v'", acct.Email) } + jes, err := a.str.FindAllLatestBandwidthLimitJournalForAccount(acct.Id, trx) + if err != nil { + return errors.Wrapf(err, "error finding latest bandwidth limit journal entries for account '%v'", acct.Email) + } + limitedBackends := make(map[sdk.BackendMode]bool) + for _, je := range jes { + if je.LimitClassId != nil { + lc, err := a.str.GetLimitClass(*je.LimitClassId, trx) + if err != nil { + return err + } + if lc.BackendMode != nil && lc.LimitAction == store.LimitLimitAction { + limitedBackends[*lc.BackendMode] = true + } + } + } + edge, err := zrokEdgeSdk.Client(a.zCfg) if err != nil { return err @@ -39,8 +56,8 @@ func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, bwc store.B } for _, shr := range shrs { - // TODO: when relaxing unscoped classes; need to not relax other scoped limits - if !bwc.IsScoped() || bwc.GetBackendMode() == sdk.BackendMode(shr.BackendMode) { + _, stayLimited := limitedBackends[sdk.BackendMode(shr.BackendMode)] + if (!bwc.IsScoped() && !stayLimited) || bwc.GetBackendMode() == sdk.BackendMode(shr.BackendMode) { switch shr.ShareMode { case string(sdk.PublicShareMode): if err := relaxPublicShare(a.str, edge, shr, trx); err != nil { diff --git a/controller/store/bandwidthLimitJournal.go b/controller/store/bandwidthLimitJournal.go index 5b867a36..b38f8429 100644 --- a/controller/store/bandwidthLimitJournal.go +++ b/controller/store/bandwidthLimitJournal.go @@ -90,6 +90,22 @@ func (str *Store) FindAllBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJ return jes, nil } +func (str *Store) FindAllLatestBandwidthLimitJournalForAccount(acctId int, trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) { + rows, err := trx.Queryx("select * from bandwidth_limit_journal where account_id = $1", acctId) + if err != nil { + return nil, errors.Wrap(err, "error finding all for account from bandwidth_limit_journal") + } + var jes []*BandwidthLimitJournalEntry + for rows.Next() { + je := &BandwidthLimitJournalEntry{} + if err := rows.StructScan(je); err != nil { + return nil, errors.Wrap(err, "error scanning bandwidth_limit_journal") + } + jes = append(jes, je) + } + return jes, nil +} + func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) { rows, err := trx.Queryx("select id, account_id, limit_class_id, action, rx_bytes, tx_bytes, created_at, updated_at from bandwidth_limit_journal where id in (select max(id) as id from bandwidth_limit_journal group by account_id)") if err != nil {