share limits check owned by limits.Agent now (#277)

This commit is contained in:
Michael Quigley 2023-03-21 16:34:45 -04:00
parent 79e9f484dc
commit d0dd04a141
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
4 changed files with 53 additions and 43 deletions

View File

@ -52,7 +52,7 @@ func Run(inCfg *config.Config) error {
api.MetadataOverviewHandler = metadata.OverviewHandlerFunc(overviewHandler)
api.MetadataVersionHandler = metadata.VersionHandlerFunc(versionHandler)
api.ShareAccessHandler = newAccessHandler()
api.ShareShareHandler = newShareHandler(cfg.Limits)
api.ShareShareHandler = newShareHandler()
api.ShareUnaccessHandler = newUnaccessHandler()
api.ShareUnshareHandler = newUnshareHandler()
api.ShareUpdateShareHandler = newUpdateShareHandler()

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"github.com/go-openapi/runtime/middleware"
"github.com/jmoiron/sqlx"
"github.com/openziti/zrok/controller/store"
"github.com/openziti/zrok/controller/zrokEdgeSdk"
"github.com/openziti/zrok/rest_model_zrok"
@ -20,14 +21,14 @@ func newEnableHandler() *enableHandler {
func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_model_zrok.Principal) middleware.Responder {
// start transaction early; if it fails, don't bother creating ziti resources
tx, err := str.Begin()
trx, err := str.Begin()
if err != nil {
logrus.Errorf("error starting transaction for user '%v': %v", principal.Email, err)
return environment.NewEnableInternalServerError()
}
defer func() { _ = tx.Rollback() }()
defer func() { _ = trx.Rollback() }()
if err := h.checkLimits(principal); err != nil {
if err := h.checkLimits(principal, trx); err != nil {
logrus.Errorf("limits error for user '%v': %v", principal.Email, err)
return environment.NewEnableUnauthorized()
}
@ -67,14 +68,14 @@ func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_
Host: params.Body.Host,
Address: realRemoteAddress(params.HTTPRequest),
ZId: envZId,
}, tx)
}, trx)
if err != nil {
logrus.Errorf("error storing created identity for user '%v': %v", principal.Email, err)
_ = tx.Rollback()
_ = trx.Rollback()
return environment.NewEnableInternalServerError()
}
if err := tx.Commit(); err != nil {
if err := trx.Commit(); err != nil {
logrus.Errorf("error committing for user '%v': %v", principal.Email, err)
return environment.NewEnableInternalServerError()
}
@ -96,15 +97,15 @@ func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_
return resp
}
func (h *enableHandler) checkLimits(principal *rest_model_zrok.Principal) error {
func (h *enableHandler) checkLimits(principal *rest_model_zrok.Principal, trx *sqlx.Tx) error {
if !principal.Limitless {
if limitsAgent != nil {
ok, err := limitsAgent.CanCreateEnvironment(int(principal.ID))
ok, err := limitsAgent.CanCreateEnvironment(int(principal.ID), trx)
if err != nil {
return errors.Wrapf(err, "error checking limits for '%v'", principal.Email)
return errors.Wrapf(err, "error checking environment limits for '%v'", principal.Email)
}
if !ok {
return errors.Wrapf(err, "environment limit check failed for '%v'", principal.Email)
return errors.Errorf("environment limit check failed for '%v'", principal.Email)
}
}
}

View File

@ -43,14 +43,8 @@ func (a *Agent) Stop() {
<-a.join
}
func (a *Agent) CanCreateEnvironment(acctId int) (bool, error) {
if a.cfg.Environments > Unlimited {
trx, err := a.str.Begin()
if err != nil {
return false, errors.Wrap(err, "error creating transaction")
}
defer func() { _ = trx.Rollback() }()
func (a *Agent) CanCreateEnvironment(acctId int, trx *sqlx.Tx) (bool, error) {
if a.cfg.Enforcing && a.cfg.Environments > Unlimited {
envs, err := a.str.FindEnvironmentsForAccount(acctId, trx)
if err != nil {
return false, err
@ -62,6 +56,28 @@ func (a *Agent) CanCreateEnvironment(acctId int) (bool, error) {
return true, nil
}
func (a *Agent) CanCreateShare(acctId int, trx *sqlx.Tx) (bool, error) {
if a.cfg.Enforcing && a.cfg.Shares > Unlimited {
envs, err := a.str.FindEnvironmentsForAccount(acctId, trx)
if err != nil {
return false, err
}
total := 0
for i := range envs {
shrs, err := a.str.FindSharesForEnvironment(envs[i].Id, trx)
if err != nil {
return false, errors.Wrapf(err, "unable to find shares for environment '%v'", envs[i].ZId)
}
total += len(shrs)
if total+1 > a.cfg.Shares {
return false, nil
}
logrus.Infof("total = %d", total)
}
}
return true, nil
}
func (a *Agent) Handle(u *metrics.Usage) error {
logrus.Debugf("handling: %v", u)
a.queue <- u

View File

@ -3,7 +3,6 @@ package controller
import (
"github.com/go-openapi/runtime/middleware"
"github.com/jmoiron/sqlx"
"github.com/openziti/zrok/controller/limits"
"github.com/openziti/zrok/controller/store"
"github.com/openziti/zrok/controller/zrokEdgeSdk"
"github.com/openziti/zrok/rest_model_zrok"
@ -12,27 +11,23 @@ import (
"github.com/sirupsen/logrus"
)
type shareHandler struct {
cfg *limits.Config
}
type shareHandler struct{}
func newShareHandler(cfg *limits.Config) *shareHandler {
return &shareHandler{cfg: cfg}
func newShareHandler() *shareHandler {
return &shareHandler{}
}
func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zrok.Principal) middleware.Responder {
logrus.Infof("handling")
tx, err := str.Begin()
trx, err := str.Begin()
if err != nil {
logrus.Errorf("error starting transaction: %v", err)
return share.NewShareInternalServerError()
}
defer func() { _ = tx.Rollback() }()
defer func() { _ = trx.Rollback() }()
envZId := params.Body.EnvZID
envId := 0
envs, err := str.FindEnvironmentsForAccount(int(principal.ID), tx)
envs, err := str.FindEnvironmentsForAccount(int(principal.ID), trx)
if err == nil {
found := false
for _, env := range envs {
@ -52,7 +47,7 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
return share.NewShareInternalServerError()
}
if err := h.checkLimits(principal, envs, tx); err != nil {
if err := h.checkLimits(principal, trx); err != nil {
logrus.Errorf("limits error: %v", err)
return share.NewShareUnauthorized()
}
@ -80,7 +75,7 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
var frontendZIds []string
var frontendTemplates []string
for _, frontendSelection := range params.Body.FrontendSelection {
sfe, err := str.FindFrontendPubliclyNamed(frontendSelection, tx)
sfe, err := str.FindFrontendPubliclyNamed(frontendSelection, trx)
if err != nil {
logrus.Error(err)
return share.NewShareNotFound()
@ -126,13 +121,13 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
sshr.FrontendEndpoint = &sshr.ShareMode
}
sid, err := str.CreateShare(envId, sshr, tx)
sid, err := str.CreateShare(envId, sshr, trx)
if err != nil {
logrus.Errorf("error creating share record: %v", err)
return share.NewShareInternalServerError()
}
if err := tx.Commit(); err != nil {
if err := trx.Commit(); err != nil {
logrus.Errorf("error committing share record: %v", err)
return share.NewShareInternalServerError()
}
@ -144,17 +139,15 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
})
}
func (h *shareHandler) checkLimits(principal *rest_model_zrok.Principal, envs []*store.Environment, tx *sqlx.Tx) error {
if !principal.Limitless && h.cfg.Shares > limits.Unlimited {
total := 0
for i := range envs {
shrs, err := str.FindSharesForEnvironment(envs[i].Id, tx)
func (h *shareHandler) checkLimits(principal *rest_model_zrok.Principal, trx *sqlx.Tx) error {
if !principal.Limitless {
if limitsAgent != nil {
ok, err := limitsAgent.CanCreateShare(int(principal.ID), trx)
if err != nil {
return errors.Errorf("unable to find shares for environment '%v': %v", envs[i].ZId, err)
return errors.Wrapf(err, "error checking share limits for '%v'", principal.Email)
}
total += len(shrs)
if total+1 > h.cfg.Shares {
return errors.Errorf("would exceed shares limit of %d for '%v'", h.cfg.Shares, principal.Email)
if !ok {
return errors.Errorf("share limit check failed for '%v'", principal.Email)
}
}
}