From 97dbd197d64388553a89e86b4423ade84785bc3b Mon Sep 17 00:00:00 2001 From: Michael Quigley Date: Tue, 4 Jun 2024 14:06:44 -0400 Subject: [PATCH] massive bandwidth limits rewrite to support limit classes (#606) --- controller/limits/agent.go | 337 +++++++++++++--------- controller/limits/limitAction.go | 2 +- controller/limits/limitClasses.go | 4 + controller/limits/model.go | 6 +- controller/limits/relaxAction.go | 2 +- controller/limits/warningAction.go | 19 +- controller/store/appliedLimitClass.go | 2 +- controller/store/bandwidthLimitJournal.go | 55 ++++ controller/store/limitClass.go | 31 +- controller/store/share.go | 8 + 10 files changed, 314 insertions(+), 152 deletions(-) diff --git a/controller/limits/agent.go b/controller/limits/agent.go index 504da8fe..86da6240 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -217,87 +217,82 @@ func (a *Agent) enforce(u *metrics.Usage) error { return nil } - shr, err := a.str.FindShareWithToken(u.ShareToken, trx) + shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx) if err != nil { return err } - logrus.Infof("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode) + logrus.Debugf("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode) - if enforce, warning, rxBytes, txBytes, err := a.checkBandwidthLimit(u.AccountId); err == nil { - if enforce { - enforced := false - var enforcedAt time.Time - if empty, err := a.str.IsBandwidthLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty { - if latest, err := a.str.FindLatestBandwidthLimitJournal(int(u.AccountId), trx); err == nil { - enforced = latest.Action == store.LimitLimitAction + alcs, err := a.str.FindAppliedLimitClassesForAccount(int(u.AccountId), trx) + if err != nil { + return err + } + exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, alcs) + if err != nil { + return errors.Wrap(err, "error checking limit classes") + } + + if exceededLc != nil { + enforced := false + var enforcedAt time.Time + + if exceededLc.IsGlobal() { + if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty { + if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil { + enforced = latest.Action == exceededLc.GetLimitAction() enforcedAt = latest.UpdatedAt } } - - if !enforced { - _, err := a.str.CreateBandwidthLimitJournalEntry(&store.BandwidthLimitJournalEntry{ - AccountId: int(u.AccountId), - RxBytes: rxBytes, - TxBytes: txBytes, - Action: store.LimitLimitAction, - }, trx) - if err != nil { - return err + } else { + if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty { + if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil { + enforced = latest.Action == exceededLc.GetLimitAction() + enforcedAt = latest.UpdatedAt } - acct, err := a.str.GetAccount(int(u.AccountId), trx) - if err != nil { - return err - } - for _, action := range a.limitActions { - if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil { - return errors.Wrapf(err, "%v", reflect.TypeOf(action).String()) - } - } - if err := trx.Commit(); err != nil { - return err - } - } else { - logrus.Debugf("already enforced limit for account '#%d' at %v", u.AccountId, enforcedAt) - } - - } else if warning { - warned := false - var warnedAt time.Time - if empty, err := a.str.IsBandwidthLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty { - if latest, err := a.str.FindLatestBandwidthLimitJournal(int(u.AccountId), trx); err == nil { - warned = latest.Action == store.WarningLimitAction || latest.Action == store.LimitLimitAction - warnedAt = latest.UpdatedAt - } - } - - if !warned { - _, err := a.str.CreateBandwidthLimitJournalEntry(&store.BandwidthLimitJournalEntry{ - AccountId: int(u.AccountId), - RxBytes: rxBytes, - TxBytes: txBytes, - Action: store.WarningLimitAction, - }, trx) - if err != nil { - return err - } - acct, err := a.str.GetAccount(int(u.AccountId), trx) - if err != nil { - return err - } - for _, action := range a.warningActions { - if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil { - return errors.Wrapf(err, "%v", reflect.TypeOf(action).String()) - } - } - if err := trx.Commit(); err != nil { - return err - } - } else { - logrus.Debugf("already warned account '#%d' at %v", u.AccountId, warnedAt) } } - } else { - logrus.Error(err) + + if !enforced { + je := &store.BandwidthLimitJournalEntry{ + AccountId: int(u.AccountId), + RxBytes: rxBytes, + TxBytes: txBytes, + Action: exceededLc.GetLimitAction(), + } + if !exceededLc.IsGlobal() { + lcId := exceededLc.GetLimitClassId() + je.LimitClassId = &lcId + } + _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx) + + if err != nil { + return err + } + acct, err := a.str.GetAccount(int(u.AccountId), trx) + if err != nil { + return err + } + switch exceededLc.GetLimitAction() { + case store.LimitLimitAction: + for _, limitAction := range a.limitActions { + if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, trx); err != nil { + return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String()) + } + } + + case store.WarningLimitAction: + for _, warningAction := range a.warningActions { + if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, trx); err != nil { + return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String()) + } + } + } + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt) + } } return nil @@ -314,36 +309,77 @@ func (a *Agent) relax() error { commit := false - if aljs, err := a.str.FindAllLatestBandwidthLimitJournal(trx); err == nil { - for _, alj := range aljs { - if acct, err := a.str.GetAccount(alj.AccountId, trx); err == nil { - if alj.Action == store.WarningLimitAction || alj.Action == store.LimitLimitAction { - if enforce, warning, rxBytes, txBytes, err := a.checkBandwidthLimit(int64(alj.AccountId)); err == nil { - if !enforce && !warning { - if alj.Action == store.LimitLimitAction { - // run relax actions for account - for _, action := range a.relaxActions { - if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil { - return errors.Wrapf(err, "%v", reflect.TypeOf(action).String()) - } - } - } else { - logrus.Infof("relaxing warning for '%v'", acct.Email) - } - if err := a.str.DeleteBandwidthLimitJournal(acct.Id, trx); err == nil { - commit = true - } else { - logrus.Errorf("error deleting account_limit_journal for '%v': %v", acct.Email, err) - } - } else { - logrus.Infof("account '%v' still over limit", acct.Email) - } - } else { - logrus.Errorf("error checking account limit for '%v': %v", acct.Email, err) - } + if bwjes, err := a.str.FindAllBandwidthLimitJournal(trx); err == nil { + periodBw := make(map[int]struct { + rx int64 + tx int64 + }) + + accounts := make(map[int]*store.Account) + + for _, bwje := range bwjes { + if _, found := accounts[bwje.AccountId]; !found { + if acct, err := a.str.GetAccount(bwje.AccountId, trx); err == nil { + accounts[bwje.AccountId] = acct + } else { + return err + } + } + + var bwc store.BandwidthClass + if bwje.LimitClassId != nil { + globalBwcs := newConfigBandwidthClasses(a.cfg.Bandwidth) + if bwje.Action == store.WarningLimitAction { + bwc = globalBwcs[0] + } else { + bwc = globalBwcs[1] } } else { - logrus.Errorf("error getting account for '#%d': %v", alj.AccountId, err) + lc, err := a.str.GetLimitClass(*bwje.LimitClassId, trx) + if err != nil { + return err + } + bwc = lc + } + + if _, found := periodBw[bwc.GetPeriodMinutes()]; !found { + rx, tx, err := a.ifx.totalRxTxForAccount(int64(bwje.AccountId), time.Duration(bwc.GetPeriodMinutes())*time.Minute) + if err != nil { + return err + } + periodBw[bwc.GetPeriodMinutes()] = struct { + rx int64 + tx int64 + }{ + rx: rx, + tx: tx, + } + } + + used := periodBw[bwc.GetPeriodMinutes()] + if !a.limitExceeded(used.rx, used.tx, bwc) { + if bwc.GetLimitAction() == store.LimitLimitAction { + for _, action := range a.relaxActions { + if err := action.HandleAccount(accounts[bwje.AccountId], used.rx, used.tx, bwc, trx); err != nil { + return errors.Wrapf(err, "%v", reflect.TypeOf(action).String()) + } + } + } else { + logrus.Infof("relaxing warning for '%v'", accounts[bwje.AccountId].Email) + } + var lcId *int + if !bwc.IsGlobal() { + newLcId := 0 + newLcId = bwc.GetLimitClassId() + lcId = &newLcId + } + if err := a.str.DeleteBandwidthLimitJournalEntryForLimitClass(bwje.AccountId, lcId, trx); err == nil { + commit = true + } else { + logrus.Errorf("error deleting bandwidth limit journal entry for '%v': %v", accounts[bwje.AccountId].Email, err) + } + } else { + logrus.Infof("account '%v' still over limit: %v", accounts[bwje.AccountId].Email, bwc) } } } else { @@ -359,44 +395,87 @@ func (a *Agent) relax() error { return nil } -func (a *Agent) checkBandwidthLimit(acctId int64) (enforce, warning bool, rxBytes, txBytes int64, err error) { - period := 24 * time.Hour - limit := DefaultBandwidthPerPeriod() - if a.cfg.Bandwidth != nil { - limit = a.cfg.Bandwidth +func (a *Agent) isOverLimitClass(u *metrics.Usage, alcs []*store.LimitClass) (store.BandwidthClass, int64, int64, error) { + periodBw := make(map[int]struct { + rx int64 + tx int64 + }) + + var allBwcs []store.BandwidthClass + for _, alc := range alcs { + allBwcs = append(allBwcs, alc) } - if limit.Period > 0 { - period = limit.Period - } - rx, tx, err := a.ifx.totalRxTxForAccount(acctId, period) - if err != nil { - logrus.Error(err) + for _, globBwc := range newConfigBandwidthClasses(a.cfg.Bandwidth) { + allBwcs = append(allBwcs, globBwc) } - enforce, warning = a.checkLimit(limit, rx, tx) - return enforce, warning, rx, tx, nil + // find period data for each class + for _, bwc := range allBwcs { + if _, found := periodBw[bwc.GetPeriodMinutes()]; !found { + rx, tx, err := a.ifx.totalRxTxForAccount(u.AccountId, time.Minute*time.Duration(bwc.GetPeriodMinutes())) + if err != nil { + return nil, 0, 0, errors.Wrapf(err, "error getting rx/tx for account '%d'", u.AccountId) + } + periodBw[bwc.GetPeriodMinutes()] = struct { + rx int64 + tx int64 + }{ + rx: rx, + tx: tx, + } + } + } + + // find the highest, most specific limit class that has been exceeded + var selectedLc store.BandwidthClass + selectedLcPoints := -1 + var rxBytes int64 + var txBytes int64 + for _, bwc := range allBwcs { + points := a.bandwidthClassPoints(bwc) + if points >= selectedLcPoints { + period := periodBw[bwc.GetPeriodMinutes()] + if a.limitExceeded(period.rx, period.tx, bwc) { + selectedLc = bwc + selectedLcPoints = points + rxBytes = period.rx + txBytes = period.tx + } + } + } + + return selectedLc, rxBytes, txBytes, nil } -func (a *Agent) checkLimit(cfg *BandwidthPerPeriod, rx, tx int64) (enforce, warning bool) { - if cfg.Limit.Rx != Unlimited && rx > cfg.Limit.Rx { - return true, false +func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int { + points := 0 + if !bwc.IsGlobal() { + points++ } - if cfg.Limit.Tx != Unlimited && tx > cfg.Limit.Tx { - return true, false + if bwc.GetLimitAction() == store.WarningLimitAction { + points++ } - if cfg.Limit.Total != Unlimited && rx+tx > cfg.Limit.Total { - return true, false + if bwc.GetLimitAction() == store.LimitLimitAction { + points += 2 } - - if cfg.Warning.Rx != Unlimited && rx > cfg.Warning.Rx { - return false, true + if bwc.GetShareMode() != "" { + points += 5 } - if cfg.Warning.Tx != Unlimited && tx > cfg.Warning.Tx { - return false, true + if bwc.GetBackendMode() != "" { + points += 10 } - if cfg.Warning.Total != Unlimited && rx+tx > cfg.Warning.Total { - return false, true - } - - return false, false + return points +} + +func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool { + if bwc.GetTxBytes() != Unlimited && tx >= bwc.GetTxBytes() { + return true + } + if bwc.GetRxBytes() != Unlimited && rx >= bwc.GetRxBytes() { + return true + } + if bwc.GetTxBytes() != Unlimited && bwc.GetRxBytes() != Unlimited && tx+rx >= bwc.GetTxBytes()+bwc.GetRxBytes() { + return true + } + return false } diff --git a/controller/limits/limitAction.go b/controller/limits/limitAction.go index d4e5ee05..6b857ee0 100644 --- a/controller/limits/limitAction.go +++ b/controller/limits/limitAction.go @@ -17,7 +17,7 @@ func newLimitAction(str *store.Store, zCfg *zrokEdgeSdk.Config) *limitAction { return &limitAction{str, zCfg} } -func (a *limitAction) HandleAccount(acct *store.Account, _, _ int64, _ *BandwidthPerPeriod, trx *sqlx.Tx) error { +func (a *limitAction) HandleAccount(acct *store.Account, _, _ int64, _ store.BandwidthClass, trx *sqlx.Tx) error { logrus.Infof("limiting '%v'", acct.Email) envs, err := a.str.FindEnvironmentsForAccount(acct.Id, trx) diff --git a/controller/limits/limitClasses.go b/controller/limits/limitClasses.go index e6a22c19..a2d11c19 100644 --- a/controller/limits/limitClasses.go +++ b/controller/limits/limitClasses.go @@ -30,6 +30,10 @@ func (bc *configBandwidthClass) IsGlobal() bool { return true } +func (bc *configBandwidthClass) GetLimitClassId() int { + return -1 +} + func (bc *configBandwidthClass) GetShareMode() sdk.ShareMode { return "" } diff --git a/controller/limits/model.go b/controller/limits/model.go index c13f8aae..7044c7d4 100644 --- a/controller/limits/model.go +++ b/controller/limits/model.go @@ -6,13 +6,13 @@ import ( ) type AccountAction interface { - HandleAccount(a *store.Account, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error + HandleAccount(a *store.Account, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error } type EnvironmentAction interface { - HandleEnvironment(e *store.Environment, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error + HandleEnvironment(e *store.Environment, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error } type ShareAction interface { - HandleShare(s *store.Share, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error + HandleShare(s *store.Share, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error } diff --git a/controller/limits/relaxAction.go b/controller/limits/relaxAction.go index db3c33a4..7700e3e2 100644 --- a/controller/limits/relaxAction.go +++ b/controller/limits/relaxAction.go @@ -19,7 +19,7 @@ func newRelaxAction(str *store.Store, zCfg *zrokEdgeSdk.Config) *relaxAction { return &relaxAction{str, zCfg} } -func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, _ *BandwidthPerPeriod, trx *sqlx.Tx) error { +func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, _ store.BandwidthClass, trx *sqlx.Tx) error { logrus.Infof("relaxing '%v'", acct.Email) envs, err := a.str.FindEnvironmentsForAccount(acct.Id, trx) diff --git a/controller/limits/warningAction.go b/controller/limits/warningAction.go index 829f629c..8e1841fd 100644 --- a/controller/limits/warningAction.go +++ b/controller/limits/warningAction.go @@ -7,6 +7,7 @@ import ( "github.com/openziti/zrok/util" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "time" ) type warningAction struct { @@ -18,27 +19,27 @@ func newWarningAction(cfg *emailUi.Config, str *store.Store) *warningAction { return &warningAction{str, cfg} } -func (a *warningAction) HandleAccount(acct *store.Account, rxBytes, txBytes int64, limit *BandwidthPerPeriod, _ *sqlx.Tx) error { +func (a *warningAction) HandleAccount(acct *store.Account, rxBytes, txBytes int64, limit store.BandwidthClass, _ *sqlx.Tx) error { logrus.Infof("warning '%v'", acct.Email) if a.cfg != nil { rxLimit := "(unlimited bytes)" - if limit.Limit.Rx != Unlimited { - rxLimit = util.BytesToSize(limit.Limit.Rx) + if limit.GetRxBytes() != Unlimited { + rxLimit = util.BytesToSize(limit.GetRxBytes()) } txLimit := "(unlimited bytes)" - if limit.Limit.Tx != Unlimited { - txLimit = util.BytesToSize(limit.Limit.Tx) + if limit.GetTxBytes() != Unlimited { + txLimit = util.BytesToSize(limit.GetTxBytes()) } totalLimit := "(unlimited bytes)" - if limit.Limit.Total != Unlimited { - totalLimit = util.BytesToSize(limit.Limit.Total) + if limit.GetTotalBytes() != Unlimited { + totalLimit = util.BytesToSize(limit.GetTotalBytes()) } detail := newDetailMessage() detail = detail.append("Your account has received %v and sent %v (for a total of %v), which has triggered a transfer limit warning.", util.BytesToSize(rxBytes), util.BytesToSize(txBytes), util.BytesToSize(rxBytes+txBytes)) - detail = detail.append("This zrok instance only allows an account to receive %v, send %v, totalling not more than %v for each %v.", rxLimit, txLimit, totalLimit, limit.Period) - detail = detail.append("If you exceed the transfer limit, access to your shares will be temporarily disabled (until the last %v falls below the transfer limit)", limit.Period) + detail = detail.append("This zrok instance only allows an account to receive %v, send %v, totalling not more than %v for each %v.", rxLimit, txLimit, totalLimit, time.Duration(limit.GetPeriodMinutes())*time.Minute) + detail = detail.append("If you exceed the transfer limit, access to your shares will be temporarily disabled (until the last %v falls below the transfer limit)", time.Duration(limit.GetPeriodMinutes())*time.Minute) if err := sendLimitWarningEmail(a.cfg, acct.Email, detail); err != nil { return errors.Wrapf(err, "error sending limit warning email to '%v'", acct.Email) diff --git a/controller/store/appliedLimitClass.go b/controller/store/appliedLimitClass.go index 56c62aa2..edf25b61 100644 --- a/controller/store/appliedLimitClass.go +++ b/controller/store/appliedLimitClass.go @@ -23,7 +23,7 @@ func (str *Store) ApplyLimitClass(lc *AppliedLimitClass, trx *sqlx.Tx) (int, err return id, nil } -func (str *Store) FindLimitClassesForAccount(acctId int, trx *sqlx.Tx) ([]*LimitClass, error) { +func (str *Store) FindAppliedLimitClassesForAccount(acctId int, trx *sqlx.Tx) ([]*LimitClass, error) { rows, err := trx.Queryx("select limit_classes.* from applied_limit_classes, limit_classes where applied_limit_classes.account_id = $1 and applied_limit_classes.limit_class_id = limit_classes.id", acctId) if err != nil { return nil, errors.Wrap(err, "error finding limit classes for account") diff --git a/controller/store/bandwidthLimitJournal.go b/controller/store/bandwidthLimitJournal.go index eaeeb6eb..e3fb43c3 100644 --- a/controller/store/bandwidthLimitJournal.go +++ b/controller/store/bandwidthLimitJournal.go @@ -42,6 +42,54 @@ func (str *Store) FindLatestBandwidthLimitJournal(acctId int, trx *sqlx.Tx) (*Ba return j, nil } +func (str *Store) IsBandwidthLimitJournalEmptyForGlobal(acctId int, trx *sqlx.Tx) (bool, error) { + count := 0 + if err := trx.QueryRowx("select count(0) from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId).Scan(&count); err != nil { + return false, err + } + return count == 0, nil +} + +func (str *Store) FindLatestBandwidthLimitJournalForGlobal(acctId int, trx *sqlx.Tx) (*BandwidthLimitJournalEntry, error) { + j := &BandwidthLimitJournalEntry{} + if err := trx.QueryRowx("select * from bandwidth_limit_journal where account_id = $1 and limit_class_id is null order by id desc limit 1", acctId).Scan(&j); err != nil { + return nil, errors.Wrap(err, "error finding bandwidth_limit_journal by account_id for global") + } + return j, nil +} + +func (str *Store) IsBandwidthLimitJournalEmptyForLimitClass(acctId, lcId int, trx *sqlx.Tx) (bool, error) { + count := 0 + if err := trx.QueryRowx("select count(0) from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2", acctId, lcId).Scan(&count); err != nil { + return false, err + } + return count == 0, nil +} + +func (str *Store) FindLatestBandwidthLimitJournalForLimitClass(acctId, lcId int, trx *sqlx.Tx) (*BandwidthLimitJournalEntry, error) { + j := &BandwidthLimitJournalEntry{} + if err := trx.QueryRowx("select * from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2 order by id desc limit 1", acctId, lcId).StructScan(j); err != nil { + return nil, errors.Wrap(err, "error finding bandwidth_limit_journal by account_id and limit_class_id") + } + return j, nil +} + +func (str *Store) FindAllBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) { + rows, err := trx.Queryx("select * from bandwidth_limit_journal") + if err != nil { + return nil, errors.Wrap(err, "error finding all from bandwidth_limit_journal") + } + var jes []*BandwidthLimitJournalEntry + for rows.Next() { + je := &BandwidthLimitJournalEntry{} + if err := rows.StructScan(je); err != nil { + return nil, errors.Wrap(err, "error scanning bandwidth_limit_journal") + } + jes = append(jes, je) + } + return jes, nil +} + func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) { rows, err := trx.Queryx("select id, account_id, limit_class_id, action, rx_bytes, tx_bytes, created_at, updated_at from bandwidth_limit_journal where id in (select max(id) as id from bandwidth_limit_journal group by account_id)") if err != nil { @@ -64,3 +112,10 @@ func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error { } return nil } + +func (str *Store) DeleteBandwidthLimitJournalEntryForLimitClass(acctId int, lcId *int, trx *sqlx.Tx) error { + if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2", acctId, lcId); err != nil { + return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id = %d", acctId, lcId) + } + return nil +} diff --git a/controller/store/limitClass.go b/controller/store/limitClass.go index 9f9ae05e..f9a64e86 100644 --- a/controller/store/limitClass.go +++ b/controller/store/limitClass.go @@ -9,6 +9,7 @@ import ( type BandwidthClass interface { IsGlobal() bool + GetLimitClassId() int GetShareMode() sdk.ShareMode GetBackendMode() sdk.BackendMode GetPeriodMinutes() int @@ -33,35 +34,39 @@ type LimitClass struct { LimitAction LimitAction } -func (lc *LimitClass) IsGlobal() bool { +func (lc LimitClass) IsGlobal() bool { return false } -func (lc *LimitClass) GetShareMode() sdk.ShareMode { +func (lc LimitClass) GetLimitClassId() int { + return lc.Id +} + +func (lc LimitClass) GetShareMode() sdk.ShareMode { return lc.ShareMode } -func (lc *LimitClass) GetBackendMode() sdk.BackendMode { +func (lc LimitClass) GetBackendMode() sdk.BackendMode { return lc.BackendMode } -func (lc *LimitClass) GetPeriodMinutes() int { +func (lc LimitClass) GetPeriodMinutes() int { return lc.PeriodMinutes } -func (lc *LimitClass) GetRxBytes() int64 { +func (lc LimitClass) GetRxBytes() int64 { return lc.RxBytes } -func (lc *LimitClass) GetTxBytes() int64 { +func (lc LimitClass) GetTxBytes() int64 { return lc.TxBytes } -func (lc *LimitClass) GetTotalBytes() int64 { +func (lc LimitClass) GetTotalBytes() int64 { return lc.TotalBytes } -func (lc *LimitClass) GetLimitAction() LimitAction { +func (lc LimitClass) GetLimitAction() LimitAction { return lc.LimitAction } @@ -74,6 +79,8 @@ func (lc LimitClass) String() string { return string(out) } +var _ BandwidthClass = (*LimitClass)(nil) + func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) { stmt, err := trx.Prepare("insert into limit_classes (share_mode, backend_mode, environments, shares, reserved_shares, unique_names, period_minutes, rx_bytes, tx_bytes, total_bytes, limit_action) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) returning id") if err != nil { @@ -85,3 +92,11 @@ func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) { } return id, nil } + +func (str *Store) GetLimitClass(lcId int, trx *sqlx.Tx) (*LimitClass, error) { + lc := &LimitClass{} + if err := trx.QueryRowx("select * from limit_classes where id = $1", lcId).StructScan(lc); err != nil { + return nil, errors.Wrap(err, "error selecting limit_class by id") + } + return lc, nil +} diff --git a/controller/store/share.go b/controller/store/share.go index 305d76e3..c12aa529 100644 --- a/controller/store/share.go +++ b/controller/store/share.go @@ -65,6 +65,14 @@ func (str *Store) FindShareWithToken(shrToken string, tx *sqlx.Tx) (*Share, erro return shr, nil } +func (str *Store) FindShareWithTokenEvenIfDeleted(shrToken string, tx *sqlx.Tx) (*Share, error) { + shr := &Share{} + if err := tx.QueryRowx("select * from shares where token = $1", shrToken).StructScan(shr); err != nil { + return nil, errors.Wrap(err, "error selecting share by token, even if deleted") + } + return shr, nil +} + func (str *Store) ShareWithTokenExists(shrToken string, tx *sqlx.Tx) (bool, error) { count := 0 if err := tx.QueryRowx("select count(0) from shares where token = $1 and not deleted", shrToken).Scan(&count); err != nil {