zrok/controller/store/limitClass.go

176 lines
4.3 KiB
Go

package store
import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/openziti/zrok/sdk/golang/sdk"
"github.com/openziti/zrok/util"
"github.com/pkg/errors"
)
// Unlimited is the sentinel value (-1) that limit fields are compared against.
// LimitClass.String omits any count or byte limit that is not greater than
// Unlimited.
const Unlimited = -1
// BaseLimitClass is the method set embedded by both ResourceCountClass and
// BandwidthClass.
type BaseLimitClass interface {
// IsGlobal reports whether the class is a global one. The LimitClass
// implementation in this file always returns false.
IsGlobal() bool
// GetLimitClassId returns the numeric identifier of the class.
GetLimitClassId() int
// String returns a human-readable description of the class.
String() string
}
// ResourceCountClass extends BaseLimitClass with accessors for per-resource
// count limits.
// NOTE(review): values are presumably maximum counts with Unlimited (-1)
// meaning "no limit" — confirm against the callers that enforce them.
type ResourceCountClass interface {
BaseLimitClass
// GetEnvironments returns the environments limit.
GetEnvironments() int
// GetShares returns the shares limit.
GetShares() int
// GetReservedShares returns the reserved shares limit.
GetReservedShares() int
// GetUniqueNames returns the unique names limit.
GetUniqueNames() int
// GetShareFrontends returns the share frontends limit.
GetShareFrontends() int
}
// BandwidthClass extends BaseLimitClass with accessors for byte-count limits
// over a period, an optional backend mode scope, and the action to apply.
// NOTE(review): byte limits presumably use Unlimited (-1) for "no limit" —
// confirm against the callers that enforce them.
type BandwidthClass interface {
BaseLimitClass
// IsScoped reports whether the class is restricted to a backend mode.
IsScoped() bool
// GetBackendMode returns the backend mode the class is scoped to; the
// LimitClass implementation returns "" when no backend mode is set.
GetBackendMode() sdk.BackendMode
// GetPeriodMinutes returns the period, in minutes, for the byte limits.
GetPeriodMinutes() int
// GetRxBytes returns the received-bytes limit.
GetRxBytes() int64
// GetTxBytes returns the transmitted-bytes limit.
GetTxBytes() int64
// GetTotalBytes returns the total-bytes limit.
GetTotalBytes() int64
// GetLimitAction returns the action associated with the class.
GetLimitAction() LimitAction
}
// LimitClass is a row of the limit_classes table: it is written by
// Store.CreateLimitClass and populated via StructScan in Store.GetLimitClass.
// It carries both resource-count limits and bandwidth limits.
type LimitClass struct {
// Model is the embedded base record; it supplies the Id used by
// GetLimitClassId and String.
Model
// Label is an optional name; String prefers it over the numeric id when it
// is non-nil and non-empty.
Label *string
// BackendMode optionally scopes the class to a backend mode; nil means the
// class is not scoped (see IsScoped).
BackendMode *sdk.BackendMode
// Resource-count limits. String only renders values greater than Unlimited.
Environments int
Shares int
ReservedShares int
UniqueNames int
ShareFrontends int
// PeriodMinutes is the period the byte limits below are expressed over.
PeriodMinutes int
// Byte limits. String only renders values greater than Unlimited.
RxBytes int64
TxBytes int64
TotalBytes int64
// LimitAction is the action associated with this class.
LimitAction LimitAction
}
// IsGlobal implements BaseLimitClass. A LimitClass is never global, so this
// always returns false.
func (lc LimitClass) IsGlobal() bool {
return false
}
// IsScoped reports whether a backend mode is set on the class, i.e. whether
// BackendMode is non-nil.
func (lc LimitClass) IsScoped() bool {
return lc.BackendMode != nil
}
// GetLimitClassId returns the Id of the class (provided by the embedded Model).
func (lc LimitClass) GetLimitClassId() int {
return lc.Id
}
// GetEnvironments returns the Environments limit.
func (lc LimitClass) GetEnvironments() int {
return lc.Environments
}
// GetShares returns the Shares limit.
func (lc LimitClass) GetShares() int {
return lc.Shares
}
// GetReservedShares returns the ReservedShares limit.
func (lc LimitClass) GetReservedShares() int {
return lc.ReservedShares
}
// GetUniqueNames returns the UniqueNames limit.
func (lc LimitClass) GetUniqueNames() int {
return lc.UniqueNames
}
// GetShareFrontends returns the ShareFrontends limit.
func (lc LimitClass) GetShareFrontends() int {
return lc.ShareFrontends
}
// GetBackendMode returns the backend mode this class is scoped to, or the
// empty backend mode when BackendMode is nil.
func (lc LimitClass) GetBackendMode() sdk.BackendMode {
	if mode := lc.BackendMode; mode != nil {
		return *mode
	}
	return ""
}
// GetPeriodMinutes returns the PeriodMinutes value of the class.
func (lc LimitClass) GetPeriodMinutes() int {
return lc.PeriodMinutes
}
// GetRxBytes returns the RxBytes limit.
func (lc LimitClass) GetRxBytes() int64 {
return lc.RxBytes
}
// GetTxBytes returns the TxBytes limit.
func (lc LimitClass) GetTxBytes() int64 {
return lc.TxBytes
}
// GetTotalBytes returns the TotalBytes limit.
func (lc LimitClass) GetTotalBytes() int64 {
return lc.TotalBytes
}
// GetLimitAction returns the LimitAction of the class.
func (lc LimitClass) GetLimitAction() LimitAction {
return lc.LimitAction
}
// String renders the class as "LimitClass<...>" for logs and diagnostics.
//
// The class is identified by its quoted Label when that is non-nil and
// non-empty, otherwise by "#<Id>". The backend mode is included when set.
// Count and byte limits are included only when greater than Unlimited, and
// periodMinutes is included only when at least one byte limit is. The limit
// action is always appended last.
func (lc LimitClass) String() string {
	out := "LimitClass<"
	appendf := func(format string, args ...interface{}) {
		out += fmt.Sprintf(format, args...)
	}

	// Identity: label if present, numeric id otherwise.
	if lc.Label == nil || *lc.Label == "" {
		appendf("#%d", lc.Id)
	} else {
		out += "'" + *lc.Label + "'"
	}

	if lc.BackendMode != nil {
		appendf(", backendMode: '%s'", *lc.BackendMode)
	}

	// Resource counts, in their fixed display order.
	counts := []struct {
		format string
		value  int
	}{
		{", environments: %d", lc.Environments},
		{", shares: %d", lc.Shares},
		{", reservedShares: %d", lc.ReservedShares},
		{", uniqueNames: %d", lc.UniqueNames},
		{", shareFrontends: %d", lc.ShareFrontends},
	}
	for _, c := range counts {
		if c.value > Unlimited {
			appendf(c.format, c.value)
		}
	}

	// Bandwidth limits; the period is only meaningful if some byte limit is set.
	hasRx := lc.RxBytes > Unlimited
	hasTx := lc.TxBytes > Unlimited
	hasTotal := lc.TotalBytes > Unlimited
	if hasRx || hasTx || hasTotal {
		appendf(", periodMinutes: %d", lc.PeriodMinutes)
	}
	if hasRx {
		appendf(", rxBytes: %v", util.BytesToSize(lc.RxBytes))
	}
	if hasTx {
		appendf(", txBytes: %v", util.BytesToSize(lc.TxBytes))
	}
	if hasTotal {
		appendf(", totalBytes: %v", util.BytesToSize(lc.TotalBytes))
	}

	appendf(", limitAction: '%v'>", lc.LimitAction)
	return out
}
// Compile-time checks that *LimitClass satisfies both limit class interfaces
// declared in this file.
var (
	_ BandwidthClass     = (*LimitClass)(nil)
	_ ResourceCountClass = (*LimitClass)(nil)
)
// CreateLimitClass inserts lc into the limit_classes table within trx and
// returns the id of the newly created row.
//
// On failure it returns 0 and an error wrapped with the step that failed
// (preparing or executing the insert). The transaction is neither committed
// nor rolled back here; that is left to the caller.
func (str *Store) CreateLimitClass(lc *LimitClass, trx *sqlx.Tx) (int, error) {
	stmt, err := trx.Prepare("insert into limit_classes (label, backend_mode, environments, shares, reserved_shares, unique_names, share_frontends, period_minutes, rx_bytes, tx_bytes, total_bytes, limit_action) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) returning id")
	if err != nil {
		return 0, errors.Wrap(err, "error preparing limit_classes insert statement")
	}
	// Release the single-use prepared statement as soon as we are done with it
	// instead of leaving it open until the transaction ends. A close failure
	// cannot affect the already-scanned result, so it is deliberately ignored.
	defer func() { _ = stmt.Close() }()

	var id int
	if err := stmt.QueryRow(lc.Label, lc.BackendMode, lc.Environments, lc.Shares, lc.ReservedShares, lc.UniqueNames, lc.ShareFrontends, lc.PeriodMinutes, lc.RxBytes, lc.TxBytes, lc.TotalBytes, lc.LimitAction).Scan(&id); err != nil {
		return 0, errors.Wrap(err, "error executing limit_classes insert statement")
	}
	return id, nil
}
// GetLimitClass loads the limit_classes row whose id equals lcId, using trx.
// It returns nil and a wrapped error if the query or the struct scan fails.
func (str *Store) GetLimitClass(lcId int, trx *sqlx.Tx) (*LimitClass, error) {
	row := trx.QueryRowx("select * from limit_classes where id = $1", lcId)

	var found LimitClass
	if err := row.StructScan(&found); err != nil {
		return nil, errors.Wrap(err, "error selecting limit_class by id")
	}
	return &found, nil
}