diff --git a/controller/limits/agent.go b/controller/limits/agent.go index 2f9185c5..5c144473 100644 --- a/controller/limits/agent.go +++ b/controller/limits/agent.go @@ -1,10 +1,13 @@ package limits import ( + "fmt" + "github.com/jmoiron/sqlx" "github.com/openziti/zrok/controller/metrics" "github.com/openziti/zrok/controller/store" "github.com/openziti/zrok/controller/zrokEdgeSdk" "github.com/openziti/zrok/util" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "time" ) @@ -66,58 +69,168 @@ mainLoop: } } -func (a *Agent) enforce(u *metrics.Usage) { - acctPeriod := 24 * time.Hour - acctLimit := DefaultBandwidthPerPeriod() - if a.cfg.Bandwidth != nil && a.cfg.Bandwidth.PerAccount != nil { - acctLimit = a.cfg.Bandwidth.PerAccount - } - if acctLimit.Period > 0 { - acctPeriod = acctLimit.Period - } - acctRx, acctTx, err := a.ifx.totalRxTxForAccount(u.AccountId, acctPeriod) +func (a *Agent) enforce(u *metrics.Usage) error { + trx, err := a.str.Begin() if err != nil { - logrus.Error(err) - } - if acctLimit.Warning.Rx != Unlimited && acctRx > acctLimit.Warning.Rx { - logrus.Warnf("'%v': account over rx warning limit '%v' at '%v'", u.ShareToken, util.BytesToSize(acctLimit.Warning.Rx), util.BytesToSize(acctRx)) - } - if acctLimit.Warning.Tx != Unlimited && acctTx > acctLimit.Warning.Tx { - logrus.Warnf("'%v': account over tx warning limit '%v' at '%v'", u.ShareToken, util.BytesToSize(acctLimit.Warning.Tx), util.BytesToSize(acctTx)) - } - if acctLimit.Warning.Total != Unlimited && acctTx+acctRx > acctLimit.Warning.Total { - logrus.Warnf("'%v': account over total warning limit '%v' at '%v'", u.ShareToken, util.BytesToSize(acctLimit.Warning.Total), util.BytesToSize(acctRx+acctTx)) + return errors.Wrap(err, "error starting transaction") } + defer func() { _ = trx.Rollback() }() - envPeriod := 24 * time.Hour - envLimit := DefaultBandwidthPerPeriod() - if a.cfg.Bandwidth != nil && a.cfg.Bandwidth.PerEnvironment != nil { - envLimit = a.cfg.Bandwidth.PerEnvironment - } - if envLimit.Period > 0 { - envPeriod = envLimit.Period - } - envRx, envTx, err := a.ifx.totalRxTxForEnvironment(u.EnvironmentId, envPeriod) - if err != nil { + if enforce, warning, err := a.checkAccountLimits(u, trx); err == nil { + if enforce { + logrus.Warn("enforcing account limit") + } else if warning { + logrus.Warn("reporting account warning") + } else { + if enforce, warning, err := a.checkEnvironmentLimit(u, trx); err == nil { + if enforce { + logrus.Warn("enforcing environment limit") + } else if warning { + logrus.Warn("reporting environment warning") + } else { + if enforce, warning, err := a.checkShareLimit(u); err == nil { + if enforce { + logrus.Warn("enforcing share limit") + } else if warning { + logrus.Warn("reporting share warning") + } + } else { + logrus.Error(err) + } + } + } else { + logrus.Error(err) + } + } + } else { logrus.Error(err) } - sharePeriod := 24 * time.Hour - shareLimit := DefaultBandwidthPerPeriod() - if a.cfg.Bandwidth != nil && a.cfg.Bandwidth.PerShare != nil { - shareLimit = a.cfg.Bandwidth.PerShare - } - if shareLimit.Period > 0 { - sharePeriod = shareLimit.Period - } - shareRx, shareTx, err := a.ifx.totalRxTxForShare(u.ShareToken, sharePeriod) - if err != nil { - logrus.Error(err) - } - logrus.Infof("'%v': acct:{rx: %v, tx: %v}/%v, env:{rx: %v, tx: %v}/%v, share:{rx: %v, tx: %v}/%v", - u.ShareToken, - util.BytesToSize(acctRx), util.BytesToSize(acctTx), acctPeriod, - util.BytesToSize(envRx), util.BytesToSize(envTx), envPeriod, - util.BytesToSize(shareRx), util.BytesToSize(shareTx), sharePeriod, - ) + return nil +} + +func (a *Agent) checkAccountLimits(u *metrics.Usage, trx *sqlx.Tx) (enforce, warning bool, err error) { + acct, err := a.str.GetAccount(int(u.AccountId), trx) + if err != nil { + return false, false, errors.Wrapf(err, "error getting account '%d'", u.AccountId) + } + + period := 24 * time.Hour + limit := DefaultBandwidthPerPeriod() + if a.cfg.Bandwidth != nil && a.cfg.Bandwidth.PerAccount != nil { + limit = a.cfg.Bandwidth.PerAccount + } + if limit.Period > 0 { + period = limit.Period + } + rx, tx, err := a.ifx.totalRxTxForAccount(u.AccountId, period) + if err != nil { + logrus.Error(err) + } + + enforce, warning = a.checkLimit(limit, rx, tx) + if enforce || warning { + logrus.Warnf("'%v': %v", acct.Email, a.describeLimit(limit, rx, tx)) + } + + return enforce, warning, nil +} + +func (a *Agent) checkEnvironmentLimit(u *metrics.Usage, trx *sqlx.Tx) (enforce, warning bool, err error) { + env, err := a.str.GetEnvironment(int(u.EnvironmentId), trx) + if err != nil { + return false, false, errors.Wrapf(err, "error getting account '%d'", u.EnvironmentId) + } + + period := 24 * time.Hour + limit := DefaultBandwidthPerPeriod() + if a.cfg.Bandwidth != nil && a.cfg.Bandwidth.PerEnvironment != nil { + limit = a.cfg.Bandwidth.PerEnvironment + } + if limit.Period > 0 { + period = limit.Period + } + rx, tx, err := a.ifx.totalRxTxForEnvironment(u.EnvironmentId, period) + if err != nil { + logrus.Error(err) + } + + enforce, warning = a.checkLimit(limit, rx, tx) + if enforce || warning { + logrus.Warnf("'%v': %v", env.ZId, a.describeLimit(limit, rx, tx)) + } + + return enforce, warning, nil +} + +func (a *Agent) checkShareLimit(u *metrics.Usage) (enforce, warning bool, err error) { + period := 24 * time.Hour + limit := DefaultBandwidthPerPeriod() + if a.cfg.Bandwidth != nil && a.cfg.Bandwidth.PerShare != nil { + limit = a.cfg.Bandwidth.PerShare + } + if limit.Period > 0 { + period = limit.Period + } + rx, tx, err := a.ifx.totalRxTxForShare(u.ShareToken, period) + if err != nil { + logrus.Error(err) + } + + enforce, warning = a.checkLimit(limit, rx, tx) + if enforce || warning { + logrus.Warnf("'%v': %v", u.ShareToken, a.describeLimit(limit, rx, tx)) + } + + return enforce, warning, nil +} + +func (a *Agent) checkLimit(cfg *BandwidthPerPeriod, rx, tx int64) (enforce, warning bool) { + if cfg.Limit.Rx != Unlimited && rx > cfg.Limit.Rx { + return true, false + } + if cfg.Limit.Tx != Unlimited && tx > cfg.Limit.Tx { + return true, false + } + if cfg.Limit.Total != Unlimited && rx+tx > cfg.Limit.Total { + return true, false + } + + if cfg.Warning.Rx != Unlimited && rx > cfg.Warning.Rx { + return false, true + } + if cfg.Warning.Tx != Unlimited && tx > cfg.Warning.Tx { + return false, true + } + if cfg.Warning.Total != Unlimited && rx+tx > cfg.Warning.Total { + return false, true + } + + return false, false +} + +func (a *Agent) describeLimit(cfg *BandwidthPerPeriod, rx, tx int64) string { + out := "" + + if cfg.Limit.Rx != Unlimited && rx > cfg.Limit.Rx { + out += fmt.Sprintf("['%v' over rx limit '%v']", util.BytesToSize(rx), util.BytesToSize(cfg.Limit.Rx)) + } + if cfg.Limit.Tx != Unlimited && tx > cfg.Limit.Tx { + out += fmt.Sprintf("['%v' over tx limit '%v']", util.BytesToSize(tx), util.BytesToSize(cfg.Limit.Tx)) + } + if cfg.Limit.Total != Unlimited && rx+tx > cfg.Limit.Total { + out += fmt.Sprintf("['%v' over total limit '%v']", util.BytesToSize(rx+tx), util.BytesToSize(cfg.Limit.Total)) + } + + if cfg.Warning.Rx != Unlimited && rx > cfg.Warning.Rx { + out += fmt.Sprintf("['%v' over rx warning '%v']", util.BytesToSize(rx), util.BytesToSize(cfg.Warning.Rx)) + } + if cfg.Warning.Tx != Unlimited && tx > cfg.Warning.Tx { + out += fmt.Sprintf("['%v' over tx warning '%v']", util.BytesToSize(tx), util.BytesToSize(cfg.Warning.Tx)) + } + if cfg.Warning.Total != Unlimited && rx+tx > cfg.Warning.Total { + out += fmt.Sprintf("['%v' over total warning '%v']", util.BytesToSize(rx+tx), util.BytesToSize(cfg.Warning.Total)) + } + + return out }