massive bandwidth limits rewrite to support limit classes (#606)

This commit is contained in:
Michael Quigley 2024-06-04 14:06:44 -04:00
parent bb63921e42
commit 97dbd197d6
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
10 changed files with 314 additions and 152 deletions

View File

@ -217,87 +217,82 @@ func (a *Agent) enforce(u *metrics.Usage) error {
return nil
}
shr, err := a.str.FindShareWithToken(u.ShareToken, trx)
shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx)
if err != nil {
return err
}
logrus.Infof("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode)
logrus.Debugf("share: '%v', shareMode: '%v', backendMode: '%v'", shr.Token, shr.ShareMode, shr.BackendMode)
if enforce, warning, rxBytes, txBytes, err := a.checkBandwidthLimit(u.AccountId); err == nil {
if enforce {
enforced := false
var enforcedAt time.Time
if empty, err := a.str.IsBandwidthLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournal(int(u.AccountId), trx); err == nil {
enforced = latest.Action == store.LimitLimitAction
alcs, err := a.str.FindAppliedLimitClassesForAccount(int(u.AccountId), trx)
if err != nil {
return err
}
exceededLc, rxBytes, txBytes, err := a.isOverLimitClass(u, alcs)
if err != nil {
return errors.Wrap(err, "error checking limit classes")
}
if exceededLc != nil {
enforced := false
var enforcedAt time.Time
if exceededLc.IsGlobal() {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(int(u.AccountId), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournalForGlobal(int(u.AccountId), trx); err == nil {
enforced = latest.Action == exceededLc.GetLimitAction()
enforcedAt = latest.UpdatedAt
}
}
if !enforced {
_, err := a.str.CreateBandwidthLimitJournalEntry(&store.BandwidthLimitJournalEntry{
AccountId: int(u.AccountId),
RxBytes: rxBytes,
TxBytes: txBytes,
Action: store.LimitLimitAction,
}, trx)
if err != nil {
return err
} else {
if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(int(u.AccountId), exceededLc.GetLimitClassId(), trx); err == nil {
enforced = latest.Action == exceededLc.GetLimitAction()
enforcedAt = latest.UpdatedAt
}
acct, err := a.str.GetAccount(int(u.AccountId), trx)
if err != nil {
return err
}
for _, action := range a.limitActions {
if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
}
}
if err := trx.Commit(); err != nil {
return err
}
} else {
logrus.Debugf("already enforced limit for account '#%d' at %v", u.AccountId, enforcedAt)
}
} else if warning {
warned := false
var warnedAt time.Time
if empty, err := a.str.IsBandwidthLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty {
if latest, err := a.str.FindLatestBandwidthLimitJournal(int(u.AccountId), trx); err == nil {
warned = latest.Action == store.WarningLimitAction || latest.Action == store.LimitLimitAction
warnedAt = latest.UpdatedAt
}
}
if !warned {
_, err := a.str.CreateBandwidthLimitJournalEntry(&store.BandwidthLimitJournalEntry{
AccountId: int(u.AccountId),
RxBytes: rxBytes,
TxBytes: txBytes,
Action: store.WarningLimitAction,
}, trx)
if err != nil {
return err
}
acct, err := a.str.GetAccount(int(u.AccountId), trx)
if err != nil {
return err
}
for _, action := range a.warningActions {
if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
}
}
if err := trx.Commit(); err != nil {
return err
}
} else {
logrus.Debugf("already warned account '#%d' at %v", u.AccountId, warnedAt)
}
}
} else {
logrus.Error(err)
if !enforced {
je := &store.BandwidthLimitJournalEntry{
AccountId: int(u.AccountId),
RxBytes: rxBytes,
TxBytes: txBytes,
Action: exceededLc.GetLimitAction(),
}
if !exceededLc.IsGlobal() {
lcId := exceededLc.GetLimitClassId()
je.LimitClassId = &lcId
}
_, err := a.str.CreateBandwidthLimitJournalEntry(je, trx)
if err != nil {
return err
}
acct, err := a.str.GetAccount(int(u.AccountId), trx)
if err != nil {
return err
}
switch exceededLc.GetLimitAction() {
case store.LimitLimitAction:
for _, limitAction := range a.limitActions {
if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String())
}
}
case store.WarningLimitAction:
for _, warningAction := range a.warningActions {
if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededLc, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String())
}
}
}
if err := trx.Commit(); err != nil {
return err
}
} else {
logrus.Debugf("already enforced limit for account '%d' at %v", u.AccountId, enforcedAt)
}
}
return nil
@ -314,36 +309,77 @@ func (a *Agent) relax() error {
commit := false
if aljs, err := a.str.FindAllLatestBandwidthLimitJournal(trx); err == nil {
for _, alj := range aljs {
if acct, err := a.str.GetAccount(alj.AccountId, trx); err == nil {
if alj.Action == store.WarningLimitAction || alj.Action == store.LimitLimitAction {
if enforce, warning, rxBytes, txBytes, err := a.checkBandwidthLimit(int64(alj.AccountId)); err == nil {
if !enforce && !warning {
if alj.Action == store.LimitLimitAction {
// run relax actions for account
for _, action := range a.relaxActions {
if err := action.HandleAccount(acct, rxBytes, txBytes, a.cfg.Bandwidth, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
}
}
} else {
logrus.Infof("relaxing warning for '%v'", acct.Email)
}
if err := a.str.DeleteBandwidthLimitJournal(acct.Id, trx); err == nil {
commit = true
} else {
logrus.Errorf("error deleting account_limit_journal for '%v': %v", acct.Email, err)
}
} else {
logrus.Infof("account '%v' still over limit", acct.Email)
}
} else {
logrus.Errorf("error checking account limit for '%v': %v", acct.Email, err)
}
if bwjes, err := a.str.FindAllBandwidthLimitJournal(trx); err == nil {
periodBw := make(map[int]struct {
rx int64
tx int64
})
accounts := make(map[int]*store.Account)
for _, bwje := range bwjes {
if _, found := accounts[bwje.AccountId]; !found {
if acct, err := a.str.GetAccount(bwje.AccountId, trx); err == nil {
accounts[bwje.AccountId] = acct
} else {
return err
}
}
var bwc store.BandwidthClass
if bwje.LimitClassId != nil {
globalBwcs := newConfigBandwidthClasses(a.cfg.Bandwidth)
if bwje.Action == store.WarningLimitAction {
bwc = globalBwcs[0]
} else {
bwc = globalBwcs[1]
}
} else {
logrus.Errorf("error getting account for '#%d': %v", alj.AccountId, err)
lc, err := a.str.GetLimitClass(*bwje.LimitClassId, trx)
if err != nil {
return err
}
bwc = lc
}
if _, found := periodBw[bwc.GetPeriodMinutes()]; !found {
rx, tx, err := a.ifx.totalRxTxForAccount(int64(bwje.AccountId), time.Duration(bwc.GetPeriodMinutes())*time.Minute)
if err != nil {
return err
}
periodBw[bwc.GetPeriodMinutes()] = struct {
rx int64
tx int64
}{
rx: rx,
tx: tx,
}
}
used := periodBw[bwc.GetPeriodMinutes()]
if !a.limitExceeded(used.rx, used.tx, bwc) {
if bwc.GetLimitAction() == store.LimitLimitAction {
for _, action := range a.relaxActions {
if err := action.HandleAccount(accounts[bwje.AccountId], used.rx, used.tx, bwc, trx); err != nil {
return errors.Wrapf(err, "%v", reflect.TypeOf(action).String())
}
}
} else {
logrus.Infof("relaxing warning for '%v'", accounts[bwje.AccountId].Email)
}
var lcId *int
if !bwc.IsGlobal() {
newLcId := 0
newLcId = bwc.GetLimitClassId()
lcId = &newLcId
}
if err := a.str.DeleteBandwidthLimitJournalEntryForLimitClass(bwje.AccountId, lcId, trx); err == nil {
commit = true
} else {
logrus.Errorf("error deleting bandwidth limit journal entry for '%v': %v", accounts[bwje.AccountId].Email, err)
}
} else {
logrus.Infof("account '%v' still over limit: %v", accounts[bwje.AccountId].Email, bwc)
}
}
} else {
@ -359,44 +395,87 @@ func (a *Agent) relax() error {
return nil
}
func (a *Agent) checkBandwidthLimit(acctId int64) (enforce, warning bool, rxBytes, txBytes int64, err error) {
period := 24 * time.Hour
limit := DefaultBandwidthPerPeriod()
if a.cfg.Bandwidth != nil {
limit = a.cfg.Bandwidth
func (a *Agent) isOverLimitClass(u *metrics.Usage, alcs []*store.LimitClass) (store.BandwidthClass, int64, int64, error) {
periodBw := make(map[int]struct {
rx int64
tx int64
})
var allBwcs []store.BandwidthClass
for _, alc := range alcs {
allBwcs = append(allBwcs, alc)
}
if limit.Period > 0 {
period = limit.Period
}
rx, tx, err := a.ifx.totalRxTxForAccount(acctId, period)
if err != nil {
logrus.Error(err)
for _, globBwc := range newConfigBandwidthClasses(a.cfg.Bandwidth) {
allBwcs = append(allBwcs, globBwc)
}
enforce, warning = a.checkLimit(limit, rx, tx)
return enforce, warning, rx, tx, nil
// find period data for each class
for _, bwc := range allBwcs {
if _, found := periodBw[bwc.GetPeriodMinutes()]; !found {
rx, tx, err := a.ifx.totalRxTxForAccount(u.AccountId, time.Minute*time.Duration(bwc.GetPeriodMinutes()))
if err != nil {
return nil, 0, 0, errors.Wrapf(err, "error getting rx/tx for account '%d'", u.AccountId)
}
periodBw[bwc.GetPeriodMinutes()] = struct {
rx int64
tx int64
}{
rx: rx,
tx: tx,
}
}
}
// find the highest, most specific limit class that has been exceeded
var selectedLc store.BandwidthClass
selectedLcPoints := -1
var rxBytes int64
var txBytes int64
for _, bwc := range allBwcs {
points := a.bandwidthClassPoints(bwc)
if points >= selectedLcPoints {
period := periodBw[bwc.GetPeriodMinutes()]
if a.limitExceeded(period.rx, period.tx, bwc) {
selectedLc = bwc
selectedLcPoints = points
rxBytes = period.rx
txBytes = period.tx
}
}
}
return selectedLc, rxBytes, txBytes, nil
}
func (a *Agent) checkLimit(cfg *BandwidthPerPeriod, rx, tx int64) (enforce, warning bool) {
if cfg.Limit.Rx != Unlimited && rx > cfg.Limit.Rx {
return true, false
func (a *Agent) bandwidthClassPoints(bwc store.BandwidthClass) int {
points := 0
if !bwc.IsGlobal() {
points++
}
if cfg.Limit.Tx != Unlimited && tx > cfg.Limit.Tx {
return true, false
if bwc.GetLimitAction() == store.WarningLimitAction {
points++
}
if cfg.Limit.Total != Unlimited && rx+tx > cfg.Limit.Total {
return true, false
if bwc.GetLimitAction() == store.LimitLimitAction {
points += 2
}
if cfg.Warning.Rx != Unlimited && rx > cfg.Warning.Rx {
return false, true
if bwc.GetShareMode() != "" {
points += 5
}
if cfg.Warning.Tx != Unlimited && tx > cfg.Warning.Tx {
return false, true
if bwc.GetBackendMode() != "" {
points += 10
}
if cfg.Warning.Total != Unlimited && rx+tx > cfg.Warning.Total {
return false, true
}
return false, false
return points
}
func (a *Agent) limitExceeded(rx, tx int64, bwc store.BandwidthClass) bool {
if bwc.GetTxBytes() != Unlimited && tx >= bwc.GetTxBytes() {
return true
}
if bwc.GetRxBytes() != Unlimited && rx >= bwc.GetRxBytes() {
return true
}
if bwc.GetTxBytes() != Unlimited && bwc.GetRxBytes() != Unlimited && tx+rx >= bwc.GetTxBytes()+bwc.GetRxBytes() {
return true
}
return false
}

View File

@ -17,7 +17,7 @@ func newLimitAction(str *store.Store, zCfg *zrokEdgeSdk.Config) *limitAction {
return &limitAction{str, zCfg}
}
func (a *limitAction) HandleAccount(acct *store.Account, _, _ int64, _ *BandwidthPerPeriod, trx *sqlx.Tx) error {
func (a *limitAction) HandleAccount(acct *store.Account, _, _ int64, _ store.BandwidthClass, trx *sqlx.Tx) error {
logrus.Infof("limiting '%v'", acct.Email)
envs, err := a.str.FindEnvironmentsForAccount(acct.Id, trx)

View File

@ -30,6 +30,10 @@ func (bc *configBandwidthClass) IsGlobal() bool {
return true
}
func (bc *configBandwidthClass) GetLimitClassId() int {
return -1
}
func (bc *configBandwidthClass) GetShareMode() sdk.ShareMode {
return ""
}

View File

@ -6,13 +6,13 @@ import (
)
type AccountAction interface {
HandleAccount(a *store.Account, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error
HandleAccount(a *store.Account, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error
}
type EnvironmentAction interface {
HandleEnvironment(e *store.Environment, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error
HandleEnvironment(e *store.Environment, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error
}
type ShareAction interface {
HandleShare(s *store.Share, rxBytes, txBytes int64, limit *BandwidthPerPeriod, trx *sqlx.Tx) error
HandleShare(s *store.Share, rxBytes, txBytes int64, limit store.BandwidthClass, trx *sqlx.Tx) error
}

View File

@ -19,7 +19,7 @@ func newRelaxAction(str *store.Store, zCfg *zrokEdgeSdk.Config) *relaxAction {
return &relaxAction{str, zCfg}
}
func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, _ *BandwidthPerPeriod, trx *sqlx.Tx) error {
func (a *relaxAction) HandleAccount(acct *store.Account, _, _ int64, _ store.BandwidthClass, trx *sqlx.Tx) error {
logrus.Infof("relaxing '%v'", acct.Email)
envs, err := a.str.FindEnvironmentsForAccount(acct.Id, trx)

View File

@ -7,6 +7,7 @@ import (
"github.com/openziti/zrok/util"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"time"
)
type warningAction struct {
@ -18,27 +19,27 @@ func newWarningAction(cfg *emailUi.Config, str *store.Store) *warningAction {
return &warningAction{str, cfg}
}
func (a *warningAction) HandleAccount(acct *store.Account, rxBytes, txBytes int64, limit *BandwidthPerPeriod, _ *sqlx.Tx) error {
func (a *warningAction) HandleAccount(acct *store.Account, rxBytes, txBytes int64, limit store.BandwidthClass, _ *sqlx.Tx) error {
logrus.Infof("warning '%v'", acct.Email)
if a.cfg != nil {
rxLimit := "(unlimited bytes)"
if limit.Limit.Rx != Unlimited {
rxLimit = util.BytesToSize(limit.Limit.Rx)
if limit.GetRxBytes() != Unlimited {
rxLimit = util.BytesToSize(limit.GetRxBytes())
}
txLimit := "(unlimited bytes)"
if limit.Limit.Tx != Unlimited {
txLimit = util.BytesToSize(limit.Limit.Tx)
if limit.GetTxBytes() != Unlimited {
txLimit = util.BytesToSize(limit.GetTxBytes())
}
totalLimit := "(unlimited bytes)"
if limit.Limit.Total != Unlimited {
totalLimit = util.BytesToSize(limit.Limit.Total)
if limit.GetTotalBytes() != Unlimited {
totalLimit = util.BytesToSize(limit.GetTotalBytes())
}
detail := newDetailMessage()
detail = detail.append("Your account has received %v and sent %v (for a total of %v), which has triggered a transfer limit warning.", util.BytesToSize(rxBytes), util.BytesToSize(txBytes), util.BytesToSize(rxBytes+txBytes))
detail = detail.append("This zrok instance only allows an account to receive %v, send %v, totalling not more than %v for each %v.", rxLimit, txLimit, totalLimit, limit.Period)
detail = detail.append("If you exceed the transfer limit, access to your shares will be temporarily disabled (until the last %v falls below the transfer limit)", limit.Period)
detail = detail.append("This zrok instance only allows an account to receive %v, send %v, totalling not more than %v for each %v.", rxLimit, txLimit, totalLimit, time.Duration(limit.GetPeriodMinutes())*time.Minute)
detail = detail.append("If you exceed the transfer limit, access to your shares will be temporarily disabled (until the last %v falls below the transfer limit)", time.Duration(limit.GetPeriodMinutes())*time.Minute)
if err := sendLimitWarningEmail(a.cfg, acct.Email, detail); err != nil {
return errors.Wrapf(err, "error sending limit warning email to '%v'", acct.Email)

View File

@ -23,7 +23,7 @@ func (str *Store) ApplyLimitClass(lc *AppliedLimitClass, trx *sqlx.Tx) (int, err
return id, nil
}
func (str *Store) FindLimitClassesForAccount(acctId int, trx *sqlx.Tx) ([]*LimitClass, error) {
func (str *Store) FindAppliedLimitClassesForAccount(acctId int, trx *sqlx.Tx) ([]*LimitClass, error) {
rows, err := trx.Queryx("select limit_classes.* from applied_limit_classes, limit_classes where applied_limit_classes.account_id = $1 and applied_limit_classes.limit_class_id = limit_classes.id", acctId)
if err != nil {
return nil, errors.Wrap(err, "error finding limit classes for account")

View File

@ -42,6 +42,54 @@ func (str *Store) FindLatestBandwidthLimitJournal(acctId int, trx *sqlx.Tx) (*Ba
return j, nil
}
func (str *Store) IsBandwidthLimitJournalEmptyForGlobal(acctId int, trx *sqlx.Tx) (bool, error) {
count := 0
if err := trx.QueryRowx("select count(0) from bandwidth_limit_journal where account_id = $1 and limit_class_id is null", acctId).Scan(&count); err != nil {
return false, err
}
return count == 0, nil
}
func (str *Store) FindLatestBandwidthLimitJournalForGlobal(acctId int, trx *sqlx.Tx) (*BandwidthLimitJournalEntry, error) {
j := &BandwidthLimitJournalEntry{}
if err := trx.QueryRowx("select * from bandwidth_limit_journal where account_id = $1 and limit_class_id is null order by id desc limit 1", acctId).Scan(&j); err != nil {
return nil, errors.Wrap(err, "error finding bandwidth_limit_journal by account_id for global")
}
return j, nil
}
func (str *Store) IsBandwidthLimitJournalEmptyForLimitClass(acctId, lcId int, trx *sqlx.Tx) (bool, error) {
count := 0
if err := trx.QueryRowx("select count(0) from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2", acctId, lcId).Scan(&count); err != nil {
return false, err
}
return count == 0, nil
}
func (str *Store) FindLatestBandwidthLimitJournalForLimitClass(acctId, lcId int, trx *sqlx.Tx) (*BandwidthLimitJournalEntry, error) {
j := &BandwidthLimitJournalEntry{}
if err := trx.QueryRowx("select * from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2 order by id desc limit 1", acctId, lcId).StructScan(j); err != nil {
return nil, errors.Wrap(err, "error finding bandwidth_limit_journal by account_id and limit_class_id")
}
return j, nil
}
func (str *Store) FindAllBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) {
rows, err := trx.Queryx("select * from bandwidth_limit_journal")
if err != nil {
return nil, errors.Wrap(err, "error finding all from bandwidth_limit_journal")
}
var jes []*BandwidthLimitJournalEntry
for rows.Next() {
je := &BandwidthLimitJournalEntry{}
if err := rows.StructScan(je); err != nil {
return nil, errors.Wrap(err, "error scanning bandwidth_limit_journal")
}
jes = append(jes, je)
}
return jes, nil
}
func (str *Store) FindAllLatestBandwidthLimitJournal(trx *sqlx.Tx) ([]*BandwidthLimitJournalEntry, error) {
rows, err := trx.Queryx("select id, account_id, limit_class_id, action, rx_bytes, tx_bytes, created_at, updated_at from bandwidth_limit_journal where id in (select max(id) as id from bandwidth_limit_journal group by account_id)")
if err != nil {
@ -64,3 +112,10 @@ func (str *Store) DeleteBandwidthLimitJournal(acctId int, trx *sqlx.Tx) error {
}
return nil
}
func (str *Store) DeleteBandwidthLimitJournalEntryForLimitClass(acctId int, lcId *int, trx *sqlx.Tx) error {
if _, err := trx.Exec("delete from bandwidth_limit_journal where account_id = $1 and limit_class_id = $2", acctId, lcId); err != nil {
return errors.Wrapf(err, "error deleting from bandwidth_limit_journal for account_id = %d and limit_class_id = %d", acctId, lcId)
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
type BandwidthClass interface {
IsGlobal() bool
GetLimitClassId() int
GetShareMode() sdk.ShareMode
GetBackendMode() sdk.BackendMode
GetPeriodMinutes() int
@ -33,35 +34,39 @@ type LimitClass struct {
LimitAction LimitAction
}
func (lc *LimitClass) IsGlobal() bool {
func (lc LimitClass) IsGlobal() bool {
return false
}
func (lc *LimitClass) GetShareMode() sdk.ShareMode {
func (lc LimitClass) GetLimitClassId() int {
return lc.Id
}
func (lc LimitClass) GetShareMode() sdk.ShareMode {
return lc.ShareMode
}
func (lc *LimitClass) GetBackendMode() sdk.BackendMode {
func (lc LimitClass) GetBackendMode() sdk.BackendMode {
return lc.BackendMode
}
func (lc *LimitClass) GetPeriodMinutes() int {
func (lc LimitClass) GetPeriodMinutes() int {
return lc.PeriodMinutes
}
func (lc *LimitClass) GetRxBytes() int64 {
func (lc LimitClass) GetRxBytes() int64 {
return lc.RxBytes
}
func (lc *LimitClass) GetTxBytes() int64 {
func (lc LimitClass) GetTxBytes() int64 {
return lc.TxBytes
}
func (lc *LimitClass) GetTotalBytes() int64 {
func (lc LimitClass) GetTotalBytes() int64 {
return lc.TotalBytes
}
func (lc *LimitClass) GetLimitAction() LimitAction {
func (lc LimitClass) GetLimitAction() LimitAction {
return lc.LimitAction
}
@ -74,6 +79,8 @@ func (lc LimitClass) String() string {
return string(out)
}
var _ BandwidthClass = (*LimitClass)(nil)
func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) {
stmt, err := trx.Prepare("insert into limit_classes (share_mode, backend_mode, environments, shares, reserved_shares, unique_names, period_minutes, rx_bytes, tx_bytes, total_bytes, limit_action) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) returning id")
if err != nil {
@ -85,3 +92,11 @@ func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) {
}
return id, nil
}
func (str *Store) GetLimitClass(lcId int, trx *sqlx.Tx) (*LimitClass, error) {
lc := &LimitClass{}
if err := trx.QueryRowx("select * from limit_classes where id = $1", lcId).StructScan(lc); err != nil {
return nil, errors.Wrap(err, "error selecting limit_class by id")
}
return lc, nil
}

View File

@ -65,6 +65,14 @@ func (str *Store) FindShareWithToken(shrToken string, tx *sqlx.Tx) (*Share, erro
return shr, nil
}
func (str *Store) FindShareWithTokenEvenIfDeleted(shrToken string, tx *sqlx.Tx) (*Share, error) {
shr := &Share{}
if err := tx.QueryRowx("select * from shares where token = $1", shrToken).StructScan(shr); err != nil {
return nil, errors.Wrap(err, "error selecting share by token, even if deleted")
}
return shr, nil
}
func (str *Store) ShareWithTokenExists(shrToken string, tx *sqlx.Tx) (bool, error) {
count := 0
if err := tx.QueryRowx("select count(0) from shares where token = $1 and not deleted", shrToken).Scan(&count); err != nil {