limit class filtering (#606)

This commit is contained in:
Michael Quigley 2024-05-23 14:08:14 -04:00
parent 58eb8bfcca
commit 896a4a7845
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
4 changed files with 76 additions and 12 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/openziti/zrok/controller/metrics"
"github.com/openziti/zrok/controller/store"
"github.com/openziti/zrok/controller/zrokEdgeSdk"
"github.com/openziti/zrok/sdk/golang/sdk"
"github.com/openziti/zrok/util"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -94,7 +95,7 @@ func (a *Agent) CanCreateEnvironment(acctId int, trx *sqlx.Tx) (bool, error) {
return true, nil
}
func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, trx *sqlx.Tx) (bool, error) {
func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, shareMode sdk.ShareMode, backendMode sdk.BackendMode, trx *sqlx.Tx) (bool, error) {
if a.cfg.Enforcing {
if err := a.str.LimitCheckLock(acctId, trx); err != nil {
return false, err
@ -123,6 +124,16 @@ func (a *Agent) CanCreateShare(acctId, envId int, reserved, uniqueName bool, trx
return false, err
}
alc, err := a.str.FindLimitClassesForAccount(acctId, trx)
if err != nil {
logrus.Errorf("error finding limit classes for account with id '%d': %v", acctId, err)
return false, err
}
sortLimitClasses(alc)
if len(alc) > 0 {
logrus.Infof("selected limit class: %v", alc[0])
}
if a.cfg.Shares > Unlimited || (reserved && a.cfg.ReservedShares > Unlimited) || (reserved && uniqueName && a.cfg.UniqueNames > Unlimited) {
envs, err := a.str.FindEnvironmentsForAccount(acctId, trx)
if err != nil {

View File

@ -0,0 +1,38 @@
package limits
import (
"github.com/openziti/zrok/controller/store"
"sort"
)
func sortLimitClasses(lcs []*store.LimitClass) {
sort.Slice(lcs, func(i, j int) bool {
ipoints := limitScopePoints(lcs[i]) + modePoints(lcs[i])
jpoints := limitScopePoints(lcs[j]) + modePoints(lcs[j])
return ipoints > jpoints
})
}
func limitScopePoints(lc *store.LimitClass) int {
points := 0
switch lc.LimitScope {
case store.AccountLimitScope:
points += 1000
case store.EnvironmentLimitScope:
points += 100
case store.ShareLimitScope:
points += 10
}
return points
}
func modePoints(lc *store.LimitClass) int {
points := 0
if lc.BackendMode != "" {
points += 1
}
if lc.ShareMode != "" {
points += 1
}
return points
}

View File

@ -49,7 +49,9 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
return share.NewShareInternalServerError()
}
if err := h.checkLimits(envId, principal, params.Body.Reserved, params.Body.UniqueName != "", trx); err != nil {
shareMode := sdk.ShareMode(params.Body.ShareMode)
backendMode := sdk.BackendMode(params.Body.BackendMode)
if err := h.checkLimits(envId, principal, params.Body.Reserved, params.Body.UniqueName != "", shareMode, backendMode, trx); err != nil {
logrus.Errorf("limits error: %v", err)
return share.NewShareUnauthorized()
}
@ -190,10 +192,10 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
})
}
func (h *shareHandler) checkLimits(envId int, principal *rest_model_zrok.Principal, reserved, uniqueName bool, trx *sqlx.Tx) error {
func (h *shareHandler) checkLimits(envId int, principal *rest_model_zrok.Principal, reserved, uniqueName bool, shareMode sdk.ShareMode, backendMode sdk.BackendMode, trx *sqlx.Tx) error {
if !principal.Limitless {
if limitsAgent != nil {
ok, err := limitsAgent.CanCreateShare(int(principal.ID), envId, reserved, uniqueName, trx)
ok, err := limitsAgent.CanCreateShare(int(principal.ID), envId, reserved, uniqueName, shareMode, backendMode, trx)
if err != nil {
return errors.Wrapf(err, "error checking share limits for '%v'", principal.Email)
}

View File

@ -1,6 +1,7 @@
package store
import (
"encoding/json"
"github.com/jmoiron/sqlx"
"github.com/openziti/zrok/sdk/golang/sdk"
"github.com/pkg/errors"
@ -12,12 +13,24 @@ type LimitClass struct {
LimitAction LimitAction
ShareMode sdk.ShareMode
BackendMode sdk.BackendMode
Shares int
ReservedShares int
UniqueNames int
PeriodMinutes int
RxBytes int64
TxBytes int64
TotalBytes int64
}
func (lc LimitClass) String() string {
out, err := json.MarshalIndent(&lc, "", " ")
if err != nil {
return ""
}
return string(out)
}
func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) {
stmt, err := trx.Prepare("insert into limit_classes (limit_scope, limit_action, share_mode, backend_mode, period_minutes, rx_bytes, tx_bytes, total_bytes) values ($1, $2, $3, $4, $5, $6, $7, $8) returning id")
if err != nil {