diff --git a/controller/limits/agent.go b/controller/limits/agent.go index 02401e64..19bb0a66 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -113,27 +113,198 @@ func (a *Agent) enforce(u *metrics.Usage) error { if enforce, warning, err := a.checkAccountLimits(u, trx); err == nil { if enforce { - logrus.Warn("enforcing account limit") + enforced := false + var enforcedAt time.Time + if empty, err := a.str.IsAccountLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty { + if latest, err := a.str.FindLatestAccountLimitJournal(int(u.AccountId), trx); err == nil { + enforced = latest.Action == store.LimitAction + enforcedAt = latest.UpdatedAt + } + } - alje, err := a.str.FindLatestAccountLimitJournal(int(u.AccountId), trx) - if err != nil { - return err + if !enforced { + _, err := a.str.CreateAccountLimitJournal(&store.AccountLimitJournal{ + AccountId: int(u.AccountId), + RxBytes: u.BackendRx, + TxBytes: u.BackendTx, + Action: store.LimitAction, + }, trx) + if err != nil { + return err + } + + logrus.Warnf("enforcing account limit for '#%d': %v", u.AccountId, a.describeLimit(a.cfg.Bandwidth.PerAccount, u.BackendRx, u.BackendTx)) + + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already enforced limit for account '#%d' at %v", u.AccountId, enforcedAt) } } else if warning { - logrus.Warn("reporting account warning") + warned := false + var warnedAt time.Time + if empty, err := a.str.IsAccountLimitJournalEmpty(int(u.AccountId), trx); err == nil && !empty { + if latest, err := a.str.FindLatestAccountLimitJournal(int(u.AccountId), trx); err == nil { + warned = latest.Action == store.WarningAction || latest.Action == store.LimitAction + warnedAt = latest.UpdatedAt + } + } + + if !warned { + _, err := a.str.CreateAccountLimitJournal(&store.AccountLimitJournal{ + AccountId: int(u.AccountId), + RxBytes: u.BackendRx, + TxBytes: u.BackendTx, + Action: store.WarningAction, + }, trx) + if err != nil { + return err + } + + logrus.Warnf("warning account '#%d': %v", u.AccountId, a.describeLimit(a.cfg.Bandwidth.PerAccount, u.BackendRx, u.BackendTx)) + + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already warned account '#%d' at %v", u.AccountId, warnedAt) + } + } else { if enforce, warning, err := a.checkEnvironmentLimit(u, trx); err == nil { if enforce { - logrus.Warn("enforcing environment limit") + enforced := false + var enforcedAt time.Time + if empty, err := a.str.IsEnvironmentLimitJournalEmpty(int(u.EnvironmentId), trx); err == nil && !empty { + if latest, err := a.str.FindLatestEnvironmentLimitJournal(int(u.EnvironmentId), trx); err == nil { + enforced = latest.Action == store.LimitAction + enforcedAt = latest.UpdatedAt + } + } + + if !enforced { + _, err := a.str.CreateEnvironmentLimitJournal(&store.EnvironmentLimitJournal{ + EnvironmentId: int(u.EnvironmentId), + RxBytes: u.BackendRx, + TxBytes: u.BackendTx, + Action: store.LimitAction, + }, trx) + if err != nil { + return err + } + + logrus.Warnf("enforcing environment limit for environment '#%d': %v", u.EnvironmentId, a.describeLimit(a.cfg.Bandwidth.PerEnvironment, u.BackendRx, u.BackendTx)) + + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already enforced limit for environment '#%d' at %v", u.EnvironmentId, enforcedAt) + } + } else if warning { - logrus.Warn("reporting environment warning") + warned := false + var warnedAt time.Time + if empty, err := a.str.IsEnvironmentLimitJournalEmpty(int(u.EnvironmentId), trx); err == nil && !empty { + if latest, err := a.str.FindLatestEnvironmentLimitJournal(int(u.EnvironmentId), trx); err == nil { + warned = latest.Action == store.WarningAction || latest.Action == store.LimitAction + warnedAt = latest.UpdatedAt + } + } + + if !warned { + _, err := a.str.CreateEnvironmentLimitJournal(&store.EnvironmentLimitJournal{ + EnvironmentId: int(u.EnvironmentId), + RxBytes: u.BackendRx, + TxBytes: u.BackendTx, + Action: store.WarningAction, + }, trx) + if err != nil { + return err + } + + logrus.Warnf("warning environment '#%d': %v", u.EnvironmentId, a.describeLimit(a.cfg.Bandwidth.PerEnvironment, u.BackendRx, u.BackendTx)) + + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already warned environment '#%d' at %v", u.EnvironmentId, warnedAt) + } + } else { if enforce, warning, err := a.checkShareLimit(u); err == nil { if enforce { - logrus.Warn("enforcing share limit") + shr, err := a.str.FindShareWithToken(u.ShareToken, trx) + if err != nil { + return err + } + + enforced := false + var enforcedAt time.Time + if empty, err := a.str.IsShareLimitJournalEmpty(shr.Id, trx); err == nil && !empty { + if latest, err := a.str.FindLatestShareLimitJournal(shr.Id, trx); err == nil { + enforced = latest.Action == store.LimitAction + enforcedAt = latest.UpdatedAt + } + } + + if !enforced { + _, err := a.str.CreateShareLimitJournal(&store.ShareLimitJournal{ + ShareId: shr.Id, + RxBytes: u.BackendRx, + TxBytes: u.BackendTx, + Action: store.LimitAction, + }, trx) + if err != nil { + return err + } + + logrus.Warnf("enforcing share limit for share '%v': %v", shr.Token, a.describeLimit(a.cfg.Bandwidth.PerShare, u.BackendRx, u.BackendTx)) + + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already enforced limit for share '%v' at %v", shr.Token, enforcedAt) + } + } else if warning { - logrus.Warn("reporting share warning") + shr, err := a.str.FindShareWithToken(u.ShareToken, trx) + if err != nil { + return err + } + + warned := false + var warnedAt time.Time + if empty, err := a.str.IsShareLimitJournalEmpty(shr.Id, trx); err == nil && !empty { + if latest, err := a.str.FindLatestShareLimitJournal(shr.Id, trx); err == nil { + warned = latest.Action == store.WarningAction || latest.Action == store.LimitAction + warnedAt = latest.UpdatedAt + } + } + + if !warned { + _, err := a.str.CreateShareLimitJournal(&store.ShareLimitJournal{ + ShareId: shr.Id, + RxBytes: u.BackendRx, + TxBytes: u.BackendTx, + Action: store.WarningAction, + }, trx) + if err != nil { + return err + } + + logrus.Warnf("warning share '%v': %v", shr.Token, a.describeLimit(a.cfg.Bandwidth.PerShare, u.BackendRx, u.BackendTx)) + + if err := trx.Commit(); err != nil { + return err + } + } else { + logrus.Debugf("already warned share '%v' at %v", shr.Token, warnedAt) + } } } else { logrus.Error(err) diff --git a/controller/store/accountLimitJournal.go b/controller/store/accountLimitJournal.go index 2c719e73..9bf6881b 100644 --- a/controller/store/accountLimitJournal.go +++ b/controller/store/accountLimitJournal.go @@ -10,11 +10,11 @@ type AccountLimitJournal struct { AccountId int RxBytes int64 TxBytes int64 - Action string + Action LimitJournalAction } -func (self *Store) CreateAccountLimitJournal(j *AccountLimitJournal, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into account_limit_journal (account_id, rx_bytes, tx_bytes, action) values ($1, $2, $3, $4) returning id") +func (self *Store) CreateAccountLimitJournal(j *AccountLimitJournal, trx *sqlx.Tx) (int, error) { + stmt, err := trx.Prepare("insert into account_limit_journal (account_id, rx_bytes, tx_bytes, action) values ($1, $2, $3, $4) returning id") if err != nil { return 0, errors.Wrap(err, "error preparing account_limit_journal insert statement") } @@ -25,17 +25,17 @@ func (self *Store) CreateAccountLimitJournal(j *AccountLimitJournal, tx *sqlx.Tx return id, nil } -func (self *Store) IsAccountLimitJournalEmpty(acctId int, tx *sqlx.Tx) (bool, error) { +func (self *Store) IsAccountLimitJournalEmpty(acctId int, trx *sqlx.Tx) (bool, error) { count := 0 - if err := tx.QueryRowx("select count(0) from account_limit_journal where account_id = $1", acctId).Scan(&count); err != nil { + if err := trx.QueryRowx("select count(0) from account_limit_journal where account_id = $1", acctId).Scan(&count); err != nil { return false, err } return count == 0, nil } -func (self *Store) FindLatestAccountLimitJournal(acctId int, tx *sqlx.Tx) (*AccountLimitJournal, error) { +func (self *Store) FindLatestAccountLimitJournal(acctId int, trx *sqlx.Tx) (*AccountLimitJournal, error) { j := &AccountLimitJournal{} - if err := tx.QueryRowx("select * from account_limit_journal where account_id = $1 order by created_at desc limit 1", acctId).StructScan(j); err != nil { + if err := trx.QueryRowx("select * from account_limit_journal where account_id = $1 order by id desc limit 1", acctId).StructScan(j); err != nil { return nil, errors.Wrap(err, "error finding account_limit_journal by account_id") } return j, nil diff --git a/controller/store/accountLimitJournal_test.go b/controller/store/accountLimitJournal_test.go index ac12026a..b4e03532 100644 --- a/controller/store/accountLimitJournal_test.go +++ b/controller/store/accountLimitJournal_test.go @@ -21,7 +21,7 @@ func TestAccountLimitJournal(t *testing.T) { acctId, err := str.CreateAccount(&Account{Email: "nobody@nowehere.com", Salt: "salt", Password: "password", Token: "token", Limitless: false, Deleted: false}, trx) assert.Nil(t, err) - _, err = str.CreateAccountLimitJournal(&AccountLimitJournal{AccountId: acctId, RxBytes: 1024, TxBytes: 2048, Action: "warning"}, trx) + _, err = str.CreateAccountLimitJournal(&AccountLimitJournal{AccountId: acctId, RxBytes: 1024, TxBytes: 2048, Action: WarningAction}, trx) assert.Nil(t, err) aljEmpty, err = str.IsAccountLimitJournalEmpty(acctId, trx) @@ -33,9 +33,9 @@ func TestAccountLimitJournal(t *testing.T) { assert.NotNil(t, latestAlj) assert.Equal(t, int64(1024), latestAlj.RxBytes) assert.Equal(t, int64(2048), latestAlj.TxBytes) - assert.Equal(t, "warning", latestAlj.Action) + assert.Equal(t, WarningAction, latestAlj.Action) - _, err = str.CreateAccountLimitJournal(&AccountLimitJournal{AccountId: acctId, RxBytes: 2048, TxBytes: 4096, Action: "limit"}, trx) + _, err = str.CreateAccountLimitJournal(&AccountLimitJournal{AccountId: acctId, RxBytes: 2048, TxBytes: 4096, Action: LimitAction}, trx) assert.Nil(t, err) latestAlj, err = str.FindLatestAccountLimitJournal(acctId, trx) @@ -43,5 +43,5 @@ func TestAccountLimitJournal(t *testing.T) { assert.NotNil(t, latestAlj) assert.Equal(t, int64(2048), latestAlj.RxBytes) assert.Equal(t, int64(4096), latestAlj.TxBytes) - assert.Equal(t, "limit", latestAlj.Action) + assert.Equal(t, LimitAction, latestAlj.Action) } diff --git a/controller/store/environmentLimitJournal.go b/controller/store/environmentLimitJournal.go index 077d03ac..73cbbd62 100644 --- a/controller/store/environmentLimitJournal.go +++ b/controller/store/environmentLimitJournal.go @@ -10,11 +10,11 @@ type EnvironmentLimitJournal struct { EnvironmentId int RxBytes int64 TxBytes int64 - Action string + Action LimitJournalAction } -func (self *Store) CreateEnvironmentLimitJournal(j *EnvironmentLimitJournal, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into environment_limit_journal (environment_id, rx_bytes, tx_bytes, action) values ($1, $2, $3, $4) returning id") +func (self *Store) CreateEnvironmentLimitJournal(j *EnvironmentLimitJournal, trx *sqlx.Tx) (int, error) { + stmt, err := trx.Prepare("insert into environment_limit_journal (environment_id, rx_bytes, tx_bytes, action) values ($1, $2, $3, $4) returning id") if err != nil { return 0, errors.Wrap(err, "error preparing environment_limit_journal insert statement") } @@ -25,9 +25,17 @@ func (self *Store) CreateEnvironmentLimitJournal(j *EnvironmentLimitJournal, tx return id, nil } -func (self *Store) FindLatestEnvironmentLimitJournal(envId int, tx *sqlx.Tx) (*EnvironmentLimitJournal, error) { +func (self *Store) IsEnvironmentLimitJournalEmpty(envId int, trx *sqlx.Tx) (bool, error) { + count := 0 + if err := trx.QueryRowx("select count(0) from environment_limit_journal where environment_id = $1", envId).Scan(&count); err != nil { + return false, err + } + return count == 0, nil +} + +func (self *Store) FindLatestEnvironmentLimitJournal(envId int, trx *sqlx.Tx) (*EnvironmentLimitJournal, error) { j := &EnvironmentLimitJournal{} - if err := tx.QueryRowx("select * from environment_limit_journal where environment_id = $1", envId).StructScan(j); err != nil { + if err := trx.QueryRowx("select * from environment_limit_journal where environment_id = $1 order by created_at desc limit 1", envId).StructScan(j); err != nil { return nil, errors.Wrap(err, "error finding environment_limit_journal by environment_id") } return j, nil diff --git a/controller/store/model.go b/controller/store/model.go new file mode 100644 index 00000000..c5dae5eb --- /dev/null +++ b/controller/store/model.go @@ -0,0 +1,9 @@ +package store + +type LimitJournalAction string + +const ( + LimitAction LimitJournalAction = "limit" + WarningAction LimitJournalAction = "warning" + ClearAction LimitJournalAction = "clear" +) diff --git a/controller/store/shareLimitJournal.go b/controller/store/shareLimitJournal.go index fe97733a..8ad3f6df 100644 --- a/controller/store/shareLimitJournal.go +++ b/controller/store/shareLimitJournal.go @@ -10,7 +10,7 @@ type ShareLimitJournal struct { ShareId int RxBytes int64 TxBytes int64 - Action string + Action LimitJournalAction } func (self *Store) CreateShareLimitJournal(j *ShareLimitJournal, tx *sqlx.Tx) (int, error) { @@ -25,9 +25,17 @@ func (self *Store) CreateShareLimitJournal(j *ShareLimitJournal, tx *sqlx.Tx) (i return id, nil } +func (self *Store) IsShareLimitJournalEmpty(shrId int, trx *sqlx.Tx) (bool, error) { + count := 0 + if err := trx.QueryRowx("select count(0) from share_limit_journal where share_id = $1", shrId).Scan(&count); err != nil { + return false, err + } + return count == 0, nil +} + func (self *Store) FindLatestShareLimitJournal(shrId int, tx *sqlx.Tx) (*ShareLimitJournal, error) { j := &ShareLimitJournal{} - if err := tx.QueryRowx("select * from share_limit_journal where share_id = $1", shrId).StructScan(j); err != nil { + if err := tx.QueryRowx("select * from share_limit_journal where share_id = $1 order by created_at desc limit 1", shrId).StructScan(j); err != nil { return nil, errors.Wrap(err, "error finding share_limit_journal by share_id") } return j, nil