package limits import ( "github.com/jmoiron/sqlx" "github.com/openziti/zrok/controller/emailUi" "github.com/openziti/zrok/controller/metrics" "github.com/openziti/zrok/controller/store" "github.com/openziti/zrok/controller/zrokEdgeSdk" "github.com/openziti/zrok/sdk/golang/sdk" "github.com/openziti/zrok/util" "github.com/pkg/errors" "github.com/sirupsen/logrus" "reflect" "time" ) type Agent struct { cfg *Config ifx *influxReader zCfg *zrokEdgeSdk.Config str *store.Store queue chan *metrics.Usage warningActions []AccountAction limitActions []AccountAction relaxActions []AccountAction close chan struct{} join chan struct{} } func NewAgent(cfg *Config, ifxCfg *metrics.InfluxConfig, zCfg *zrokEdgeSdk.Config, emailCfg *emailUi.Config, str *store.Store) (*Agent, error) { a := &Agent{ cfg: cfg, ifx: newInfluxReader(ifxCfg), zCfg: zCfg, str: str, queue: make(chan *metrics.Usage, 1024), warningActions: []AccountAction{newWarningAction(emailCfg, str)}, limitActions: []AccountAction{newLimitAction(str, zCfg)}, relaxActions: []AccountAction{newRelaxAction(str, zCfg)}, close: make(chan struct{}), join: make(chan struct{}), } return a, nil } func (a *Agent) Start() { go a.run() } func (a *Agent) Stop() { close(a.close) <-a.join } func (a *Agent) CanCreateEnvironment(acctId int, trx *sqlx.Tx) (bool, error) { if a.cfg.Enforcing { if err := a.str.LimitCheckLock(acctId, trx); err != nil { return false, err } ul, err := a.getUserLimits(acctId, trx) if err != nil { return false, err } if ul.resource.GetEnvironments() > store.Unlimited { envs, err := a.str.FindEnvironmentsForAccount(acctId, trx) if err != nil { return false, err } if len(envs)+1 > a.cfg.Environments { return false, nil } } } return true, nil } func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, _ sdk.ShareMode, backendMode sdk.BackendMode, trx *sqlx.Tx) (bool, error) { if a.cfg.Enforcing { if err := a.str.LimitCheckLock(acctId, trx); err != nil { return false, err } ul, err := a.getUserLimits(acctId, trx) if err != nil { return false, err } if scopedBwc, found := ul.scopes[backendMode]; found { latestScopedJe, err := a.isBandwidthClassLimitedForAccount(acctId, scopedBwc, trx) if err != nil { return false, err } if latestScopedJe != nil { return false, nil } } else { for _, bwc := range ul.bandwidth { latestJe, err := a.isBandwidthClassLimitedForAccount(acctId, bwc, trx) if err != nil { return false, err } if latestJe != nil { return false, nil } } } rc := ul.resource if scopeRc, found := ul.scopes[backendMode]; found { rc = scopeRc } if rc.GetShares() > store.Unlimited || (reserved && rc.GetReservedShares() > store.Unlimited) || (reserved && uniqueName && rc.GetUniqueNames() > store.Unlimited) { envs, err := a.str.FindEnvironmentsForAccount(acctId, trx) if err != nil { return false, err } total := 0 reserveds := 0 uniqueNames := 0 for i := range envs { shrs, err := a.str.FindSharesForEnvironment(envs[i].Id, trx) if err != nil { return false, errors.Wrapf(err, "unable to find shares for environment '%v'", envs[i].ZId) } total += len(shrs) for _, shr := range shrs { if shr.Reserved { reserveds++ } if shr.UniqueName { uniqueNames++ } } if total+1 > rc.GetShares() { logrus.Debugf("account '#%d', environment '%d' over shares limit '%d'", acctId, envId, a.cfg.Shares) return false, nil } if reserved && reserveds+1 > rc.GetReservedShares() { logrus.Debugf("account '#%d', environment '%d' over reserved shares limit '%d'", acctId, envId, a.cfg.ReservedShares) return false, nil } if reserved && uniqueName && uniqueNames+1 > rc.GetUniqueNames() { logrus.Debugf("account '#%d', environment '%d' over unique names limit '%d'", acctId, envId, a.cfg.UniqueNames) return false, nil } } } } return true, nil } func (a *Agent) CanAccessShare(shrId int, trx *sqlx.Tx) (bool, error) { if a.cfg.Enforcing { shr, err := a.str.GetShare(shrId, trx) if err != nil { return false, err } env, err := a.str.GetEnvironment(shr.EnvironmentId, trx) if err != nil { return false, err } if env.AccountId != nil { if err := a.str.LimitCheckLock(*env.AccountId, trx); err != nil { return false, err } ul, err := a.getUserLimits(*env.AccountId, trx) if err != nil { return false, err } if scopedBwc, found := ul.scopes[sdk.BackendMode(shr.BackendMode)]; found { latestScopedJe, err := a.isBandwidthClassLimitedForAccount(*env.AccountId, scopedBwc, trx) if err != nil { return false, err } if latestScopedJe != nil { return false, nil } } else { for _, bwc := range ul.bandwidth { latestJe, err := a.isBandwidthClassLimitedForAccount(*env.AccountId, bwc, trx) if err != nil { return false, err } if latestJe != nil { return false, nil } } } rc := ul.resource if scopeRc, found := ul.scopes[sdk.BackendMode(shr.BackendMode)]; found { rc = scopeRc } if rc.GetShareFrontends() > store.Unlimited { fes, err := a.str.FindFrontendsForPrivateShare(shr.Id, trx) if err != nil { return false, err } if len(fes)+1 > rc.GetShareFrontends() { logrus.Infof("account '#%d' over frontends per share limit '%d'", *env.AccountId, rc.GetReservedShares()) return false, nil } } } else { return false, nil } } return true, nil } func (a *Agent) Handle(u *metrics.Usage) error { logrus.Debugf("handling: %v", u) a.queue <- u return nil } func (a *Agent) run() { logrus.Info("started") defer logrus.Info("stopped") lastCycle := time.Now() mainLoop: for { select { case usage := <-a.queue: if usage.ShareToken != "" { if err := a.enforce(usage); err != nil { logrus.Errorf("error running enforcement: %v", err) } if time.Since(lastCycle) > a.cfg.Cycle { if err := a.relax(); err != nil { logrus.Errorf("error running relax cycle: %v", err) } lastCycle = time.Now() } } else { logrus.Warnf("not enforcing for usage with no share token: %v", usage.String()) } case <-time.After(a.cfg.Cycle): if err := a.relax(); err != nil { logrus.Errorf("error running relax cycle: %v", err) } lastCycle = time.Now() case <-a.close: close(a.join) break mainLoop } } } func (a *Agent) enforce(u *metrics.Usage) error { trx, err := a.str.Begin() if err != nil { return errors.Wrap(err, "error starting transaction") } defer func() { _ = trx.Rollback() }() acct, err := a.str.GetAccount(int(u.AccountId), trx) if err != nil { return err } if acct.Limitless { return nil } shr, err := a.str.FindShareWithTokenEvenIfDeleted(u.ShareToken, trx) if err != nil { return err } ul, err := a.getUserLimits(int(u.AccountId), trx) if err != nil { return err } exceededBwc, rxBytes, txBytes, err := a.anyBandwidthLimitExceeded(acct, u, ul.toBandwidthArray(sdk.BackendMode(shr.BackendMode))) if err != nil { return errors.Wrap(err, "error checking limit classes") } if exceededBwc != nil { latestJe, err := a.isBandwidthClassLimitedForAccount(int(u.AccountId), exceededBwc, trx) if err != nil { return err } if latestJe == nil { je := &store.BandwidthLimitJournalEntry{ AccountId: int(u.AccountId), RxBytes: rxBytes, TxBytes: txBytes, Action: exceededBwc.GetLimitAction(), } if !exceededBwc.IsGlobal() { lcId := exceededBwc.GetLimitClassId() je.LimitClassId = &lcId } if _, err := a.str.CreateBandwidthLimitJournalEntry(je, trx); err != nil { return err } acct, err := a.str.GetAccount(int(u.AccountId), trx) if err != nil { return err } switch exceededBwc.GetLimitAction() { case store.LimitLimitAction: for _, limitAction := range a.limitActions { if err := limitAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil { return errors.Wrapf(err, "%v", reflect.TypeOf(limitAction).String()) } } case store.WarningLimitAction: for _, warningAction := range a.warningActions { if err := warningAction.HandleAccount(acct, rxBytes, txBytes, exceededBwc, ul, trx); err != nil { return errors.Wrapf(err, "%v", reflect.TypeOf(warningAction).String()) } } } if err := trx.Commit(); err != nil { return err } } else { logrus.Debugf("limit '%v' already applied for '%v' (at: %v)", exceededBwc, acct.Email, latestJe.CreatedAt) } } return nil } func (a *Agent) relax() error { logrus.Debug("relaxing") trx, err := a.str.Begin() if err != nil { return errors.Wrap(err, "error starting transaction") } defer func() { _ = trx.Rollback() }() commit := false if bwjes, err := a.str.FindAllBandwidthLimitJournal(trx); err == nil { accounts := make(map[int]*store.Account) uls := make(map[int]*userLimits) accountPeriods := make(map[int]map[int]*periodBwValues) for _, bwje := range bwjes { if _, found := accounts[bwje.AccountId]; !found { if acct, err := a.str.GetAccount(bwje.AccountId, trx); err == nil { accounts[bwje.AccountId] = acct ul, err := a.getUserLimits(acct.Id, trx) if err != nil { return errors.Wrapf(err, "error getting user limits for '%v'", acct.Email) } uls[bwje.AccountId] = ul accountPeriods[bwje.AccountId] = make(map[int]*periodBwValues) } else { return err } } var bwc store.BandwidthClass if bwje.LimitClassId == nil { globalBwcs := newConfigBandwidthClasses(a.cfg.Bandwidth) if bwje.Action == store.WarningLimitAction { bwc = globalBwcs[0] } else { bwc = globalBwcs[1] } } else { lc, err := a.str.GetLimitClass(*bwje.LimitClassId, trx) if err != nil { return err } bwc = lc } if periods, accountFound := accountPeriods[bwje.AccountId]; accountFound { if _, periodFound := periods[bwc.GetPeriodMinutes()]; !periodFound { rx, tx, err := a.ifx.totalRxTxForAccount(int64(bwje.AccountId), time.Duration(bwc.GetPeriodMinutes())*time.Minute) if err != nil { return err } periods[bwc.GetPeriodMinutes()] = &periodBwValues{rx: rx, tx: tx} accountPeriods[bwje.AccountId] = periods } } else { return errors.New("accountPeriods corrupted") } used := accountPeriods[bwje.AccountId][bwc.GetPeriodMinutes()] if !a.transferBytesExceeded(used.rx, used.tx, bwc) { if bwc.GetLimitAction() == store.LimitLimitAction { logrus.Infof("relaxing limit '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email) for _, action := range a.relaxActions { if err := action.HandleAccount(accounts[bwje.AccountId], used.rx, used.tx, bwc, uls[bwje.AccountId], trx); err != nil { return errors.Wrapf(err, "%v", reflect.TypeOf(action).String()) } } } else { logrus.Infof("relaxing warning '%v' for '%v'", bwc.String(), accounts[bwje.AccountId].Email) } if bwc.IsGlobal() { if err := a.str.DeleteBandwidthLimitJournalEntryForGlobal(bwje.AccountId, trx); err == nil { commit = true } else { logrus.Errorf("error deleting global bandwidth limit journal entry for '%v': %v", accounts[bwje.AccountId].Email, err) } } else { if err := a.str.DeleteBandwidthLimitJournalEntryForLimitClass(bwje.AccountId, *bwje.LimitClassId, trx); err == nil { commit = true } else { logrus.Errorf("error deleting bandwidth limit journal entry for '%v': %v", accounts[bwje.AccountId].Email, err) } } } else { logrus.Infof("'%v' still over limit: '%v' with rx: %v, tx: %v, total: %v", accounts[bwje.AccountId].Email, bwc, util.BytesToSize(used.rx), util.BytesToSize(used.tx), util.BytesToSize(used.rx+used.tx)) } } } else { return err } if commit { if err := trx.Commit(); err != nil { return err } } return nil } func (a *Agent) isBandwidthClassLimitedForAccount(acctId int, bwc store.BandwidthClass, trx *sqlx.Tx) (*store.BandwidthLimitJournalEntry, error) { if bwc.IsGlobal() { if empty, err := a.str.IsBandwidthLimitJournalEmptyForGlobal(acctId, trx); err == nil && !empty { je, err := a.str.FindLatestBandwidthLimitJournalForGlobal(acctId, trx) if err != nil { return nil, err } if je.Action == store.LimitLimitAction { logrus.Debugf("account '#%d' over bandwidth for global bandwidth class '%v'", acctId, bwc) return je, nil } } else if err != nil { return nil, err } } else { if empty, err := a.str.IsBandwidthLimitJournalEmptyForLimitClass(acctId, bwc.GetLimitClassId(), trx); err == nil && !empty { je, err := a.str.FindLatestBandwidthLimitJournalForLimitClass(acctId, bwc.GetLimitClassId(), trx) if err != nil { return nil, err } if je.Action == store.LimitLimitAction { logrus.Debugf("account '#%d' over bandwidth for limit class '%v'", acctId, bwc) return je, nil } } else if err != nil { return nil, err } } return nil, nil } func (a *Agent) anyBandwidthLimitExceeded(acct *store.Account, u *metrics.Usage, bwcs []store.BandwidthClass) (store.BandwidthClass, int64, int64, error) { periodBw := make(map[int]periodBwValues) var selectedLc store.BandwidthClass var rxBytes int64 var txBytes int64 for _, bwc := range bwcs { if _, found := periodBw[bwc.GetPeriodMinutes()]; !found { rx, tx, err := a.ifx.totalRxTxForAccount(u.AccountId, time.Minute*time.Duration(bwc.GetPeriodMinutes())) if err != nil { return nil, 0, 0, errors.Wrapf(err, "error getting rx/tx for account '%v'", acct.Email) } periodBw[bwc.GetPeriodMinutes()] = periodBwValues{rx: rx, tx: tx} } period := periodBw[bwc.GetPeriodMinutes()] if a.transferBytesExceeded(period.rx, period.tx, bwc) { selectedLc = bwc rxBytes = period.rx txBytes = period.tx } else { logrus.Debugf("'%v' limit ok '%v' with rx: %v, tx: %v, total: %v", acct.Email, bwc, util.BytesToSize(period.rx), util.BytesToSize(period.tx), util.BytesToSize(period.rx+period.tx)) } } if selectedLc != nil { logrus.Infof("'%v' exceeded limit '%v' with rx: %v, tx: %v, total: %v", acct.Email, selectedLc, util.BytesToSize(rxBytes), util.BytesToSize(txBytes), util.BytesToSize(rxBytes+txBytes)) } return selectedLc, rxBytes, txBytes, nil } func (a *Agent) transferBytesExceeded(rx, tx int64, bwc store.BandwidthClass) bool { if bwc.GetTxBytes() != store.Unlimited && tx >= bwc.GetTxBytes() { return true } if bwc.GetRxBytes() != store.Unlimited && rx >= bwc.GetRxBytes() { return true } if bwc.GetTotalBytes() != store.Unlimited && tx+rx >= bwc.GetTotalBytes() { return true } return false } type periodBwValues struct { rx int64 tx int64 }