share limits check owned by limits.Agent now (#277)

This commit is contained in:
Michael Quigley 2023-03-21 16:34:45 -04:00
parent 79e9f484dc
commit d0dd04a141
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
4 changed files with 53 additions and 43 deletions

View File

@ -52,7 +52,7 @@ func Run(inCfg *config.Config) error {
api.MetadataOverviewHandler = metadata.OverviewHandlerFunc(overviewHandler) api.MetadataOverviewHandler = metadata.OverviewHandlerFunc(overviewHandler)
api.MetadataVersionHandler = metadata.VersionHandlerFunc(versionHandler) api.MetadataVersionHandler = metadata.VersionHandlerFunc(versionHandler)
api.ShareAccessHandler = newAccessHandler() api.ShareAccessHandler = newAccessHandler()
api.ShareShareHandler = newShareHandler(cfg.Limits) api.ShareShareHandler = newShareHandler()
api.ShareUnaccessHandler = newUnaccessHandler() api.ShareUnaccessHandler = newUnaccessHandler()
api.ShareUnshareHandler = newUnshareHandler() api.ShareUnshareHandler = newUnshareHandler()
api.ShareUpdateShareHandler = newUpdateShareHandler() api.ShareUpdateShareHandler = newUpdateShareHandler()

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/go-openapi/runtime/middleware" "github.com/go-openapi/runtime/middleware"
"github.com/jmoiron/sqlx"
"github.com/openziti/zrok/controller/store" "github.com/openziti/zrok/controller/store"
"github.com/openziti/zrok/controller/zrokEdgeSdk" "github.com/openziti/zrok/controller/zrokEdgeSdk"
"github.com/openziti/zrok/rest_model_zrok" "github.com/openziti/zrok/rest_model_zrok"
@ -20,14 +21,14 @@ func newEnableHandler() *enableHandler {
func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_model_zrok.Principal) middleware.Responder { func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_model_zrok.Principal) middleware.Responder {
// start transaction early; if it fails, don't bother creating ziti resources // start transaction early; if it fails, don't bother creating ziti resources
tx, err := str.Begin() trx, err := str.Begin()
if err != nil { if err != nil {
logrus.Errorf("error starting transaction for user '%v': %v", principal.Email, err) logrus.Errorf("error starting transaction for user '%v': %v", principal.Email, err)
return environment.NewEnableInternalServerError() return environment.NewEnableInternalServerError()
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = trx.Rollback() }()
if err := h.checkLimits(principal); err != nil { if err := h.checkLimits(principal, trx); err != nil {
logrus.Errorf("limits error for user '%v': %v", principal.Email, err) logrus.Errorf("limits error for user '%v': %v", principal.Email, err)
return environment.NewEnableUnauthorized() return environment.NewEnableUnauthorized()
} }
@ -67,14 +68,14 @@ func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_
Host: params.Body.Host, Host: params.Body.Host,
Address: realRemoteAddress(params.HTTPRequest), Address: realRemoteAddress(params.HTTPRequest),
ZId: envZId, ZId: envZId,
}, tx) }, trx)
if err != nil { if err != nil {
logrus.Errorf("error storing created identity for user '%v': %v", principal.Email, err) logrus.Errorf("error storing created identity for user '%v': %v", principal.Email, err)
_ = tx.Rollback() _ = trx.Rollback()
return environment.NewEnableInternalServerError() return environment.NewEnableInternalServerError()
} }
if err := tx.Commit(); err != nil { if err := trx.Commit(); err != nil {
logrus.Errorf("error committing for user '%v': %v", principal.Email, err) logrus.Errorf("error committing for user '%v': %v", principal.Email, err)
return environment.NewEnableInternalServerError() return environment.NewEnableInternalServerError()
} }
@ -96,15 +97,15 @@ func (h *enableHandler) Handle(params environment.EnableParams, principal *rest_
return resp return resp
} }
func (h *enableHandler) checkLimits(principal *rest_model_zrok.Principal) error { func (h *enableHandler) checkLimits(principal *rest_model_zrok.Principal, trx *sqlx.Tx) error {
if !principal.Limitless { if !principal.Limitless {
if limitsAgent != nil { if limitsAgent != nil {
ok, err := limitsAgent.CanCreateEnvironment(int(principal.ID)) ok, err := limitsAgent.CanCreateEnvironment(int(principal.ID), trx)
if err != nil { if err != nil {
return errors.Wrapf(err, "error checking limits for '%v'", principal.Email) return errors.Wrapf(err, "error checking environment limits for '%v'", principal.Email)
} }
if !ok { if !ok {
return errors.Wrapf(err, "environment limit check failed for '%v'", principal.Email) return errors.Errorf("environment limit check failed for '%v'", principal.Email)
} }
} }
} }

View File

@ -43,14 +43,8 @@ func (a *Agent) Stop() {
<-a.join <-a.join
} }
func (a *Agent) CanCreateEnvironment(acctId int) (bool, error) { func (a *Agent) CanCreateEnvironment(acctId int, trx *sqlx.Tx) (bool, error) {
if a.cfg.Environments > Unlimited { if a.cfg.Enforcing && a.cfg.Environments > Unlimited {
trx, err := a.str.Begin()
if err != nil {
return false, errors.Wrap(err, "error creating transaction")
}
defer func() { _ = trx.Rollback() }()
envs, err := a.str.FindEnvironmentsForAccount(acctId, trx) envs, err := a.str.FindEnvironmentsForAccount(acctId, trx)
if err != nil { if err != nil {
return false, err return false, err
@ -62,6 +56,28 @@ func (a *Agent) CanCreateEnvironment(acctId int) (bool, error) {
return true, nil return true, nil
} }
func (a *Agent) CanCreateShare(acctId int, trx *sqlx.Tx) (bool, error) {
if a.cfg.Enforcing && a.cfg.Shares > Unlimited {
envs, err := a.str.FindEnvironmentsForAccount(acctId, trx)
if err != nil {
return false, err
}
total := 0
for i := range envs {
shrs, err := a.str.FindSharesForEnvironment(envs[i].Id, trx)
if err != nil {
return false, errors.Wrapf(err, "unable to find shares for environment '%v'", envs[i].ZId)
}
total += len(shrs)
if total+1 > a.cfg.Shares {
return false, nil
}
logrus.Infof("total = %d", total)
}
}
return true, nil
}
func (a *Agent) Handle(u *metrics.Usage) error { func (a *Agent) Handle(u *metrics.Usage) error {
logrus.Debugf("handling: %v", u) logrus.Debugf("handling: %v", u)
a.queue <- u a.queue <- u

View File

@ -3,7 +3,6 @@ package controller
import ( import (
"github.com/go-openapi/runtime/middleware" "github.com/go-openapi/runtime/middleware"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/openziti/zrok/controller/limits"
"github.com/openziti/zrok/controller/store" "github.com/openziti/zrok/controller/store"
"github.com/openziti/zrok/controller/zrokEdgeSdk" "github.com/openziti/zrok/controller/zrokEdgeSdk"
"github.com/openziti/zrok/rest_model_zrok" "github.com/openziti/zrok/rest_model_zrok"
@ -12,27 +11,23 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type shareHandler struct { type shareHandler struct{}
cfg *limits.Config
}
func newShareHandler(cfg *limits.Config) *shareHandler { func newShareHandler() *shareHandler {
return &shareHandler{cfg: cfg} return &shareHandler{}
} }
func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zrok.Principal) middleware.Responder { func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zrok.Principal) middleware.Responder {
logrus.Infof("handling") trx, err := str.Begin()
tx, err := str.Begin()
if err != nil { if err != nil {
logrus.Errorf("error starting transaction: %v", err) logrus.Errorf("error starting transaction: %v", err)
return share.NewShareInternalServerError() return share.NewShareInternalServerError()
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = trx.Rollback() }()
envZId := params.Body.EnvZID envZId := params.Body.EnvZID
envId := 0 envId := 0
envs, err := str.FindEnvironmentsForAccount(int(principal.ID), tx) envs, err := str.FindEnvironmentsForAccount(int(principal.ID), trx)
if err == nil { if err == nil {
found := false found := false
for _, env := range envs { for _, env := range envs {
@ -52,7 +47,7 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
return share.NewShareInternalServerError() return share.NewShareInternalServerError()
} }
if err := h.checkLimits(principal, envs, tx); err != nil { if err := h.checkLimits(principal, trx); err != nil {
logrus.Errorf("limits error: %v", err) logrus.Errorf("limits error: %v", err)
return share.NewShareUnauthorized() return share.NewShareUnauthorized()
} }
@ -80,7 +75,7 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
var frontendZIds []string var frontendZIds []string
var frontendTemplates []string var frontendTemplates []string
for _, frontendSelection := range params.Body.FrontendSelection { for _, frontendSelection := range params.Body.FrontendSelection {
sfe, err := str.FindFrontendPubliclyNamed(frontendSelection, tx) sfe, err := str.FindFrontendPubliclyNamed(frontendSelection, trx)
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
return share.NewShareNotFound() return share.NewShareNotFound()
@ -126,13 +121,13 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
sshr.FrontendEndpoint = &sshr.ShareMode sshr.FrontendEndpoint = &sshr.ShareMode
} }
sid, err := str.CreateShare(envId, sshr, tx) sid, err := str.CreateShare(envId, sshr, trx)
if err != nil { if err != nil {
logrus.Errorf("error creating share record: %v", err) logrus.Errorf("error creating share record: %v", err)
return share.NewShareInternalServerError() return share.NewShareInternalServerError()
} }
if err := tx.Commit(); err != nil { if err := trx.Commit(); err != nil {
logrus.Errorf("error committing share record: %v", err) logrus.Errorf("error committing share record: %v", err)
return share.NewShareInternalServerError() return share.NewShareInternalServerError()
} }
@ -144,17 +139,15 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
}) })
} }
func (h *shareHandler) checkLimits(principal *rest_model_zrok.Principal, envs []*store.Environment, tx *sqlx.Tx) error { func (h *shareHandler) checkLimits(principal *rest_model_zrok.Principal, trx *sqlx.Tx) error {
if !principal.Limitless && h.cfg.Shares > limits.Unlimited { if !principal.Limitless {
total := 0 if limitsAgent != nil {
for i := range envs { ok, err := limitsAgent.CanCreateShare(int(principal.ID), trx)
shrs, err := str.FindSharesForEnvironment(envs[i].Id, tx)
if err != nil { if err != nil {
return errors.Errorf("unable to find shares for environment '%v': %v", envs[i].ZId, err) return errors.Wrapf(err, "error checking share limits for '%v'", principal.Email)
} }
total += len(shrs) if !ok {
if total+1 > h.cfg.Shares { return errors.Errorf("share limit check failed for '%v'", principal.Email)
return errors.Errorf("would exceed shares limit of %d for '%v'", h.cfg.Shares, principal.Email)
} }
} }
} }