better middleware management; return 429 when over limit (#968)

This commit is contained in:
Michael Quigley 2025-06-23 15:08:43 -04:00
parent 7b1d98f9ce
commit 784417bacc
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
7 changed files with 140 additions and 42 deletions

View File

@ -34,17 +34,19 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
envId, err := h.validateEnvironment(params.Body.EnvZID, principal, trx) envId, err := h.validateEnvironment(params.Body.EnvZID, principal, trx)
if err != nil { if err != nil {
return h.handleEnvironmentError(err) logrus.Errorf("error validating environment '%v' for '%v': %v", params.Body.EnvZID, principal.Email, err)
}
if err := h.checkLimits(envId, principal, params.Body.Reserved, params.Body.UniqueName != "", sdk.ShareMode(params.Body.ShareMode), sdk.BackendMode(params.Body.BackendMode), trx); err != nil {
logrus.Errorf("limits error: %v", err)
return share.NewShareUnauthorized() return share.NewShareUnauthorized()
} }
accessGrantAcctIds, responder := h.processAccessGrants(params, principal, trx) if err := h.checkLimits(envId, principal, params.Body.Reserved, params.Body.UniqueName != "", sdk.ShareMode(params.Body.ShareMode), sdk.BackendMode(params.Body.BackendMode), trx); err != nil {
if responder != nil { logrus.Errorf("limits error for '%v': %v", principal.Email, err)
return responder return share.NewShareTooManyRequests()
}
accessGrantAcctIds, err := h.processAccessGrants(params, principal, trx)
if err != nil {
logrus.Errorf("error processing access grants: %v", err)
return share.NewShareInternalServerError()
} }
edge, err := zrokEdgeSdk.Client(cfg.Ziti) edge, err := zrokEdgeSdk.Client(cfg.Ziti)
@ -53,25 +55,29 @@ func (h *shareHandler) Handle(params share.ShareParams, principal *rest_model_zr
return share.NewShareInternalServerError() return share.NewShareInternalServerError()
} }
shrToken, responder := h.createShareToken(params.Body.Reserved, params.Body.UniqueName, trx) shrToken, err := h.createShareToken(params.Body.Reserved, params.Body.UniqueName, trx)
if responder != nil { if err != nil {
return responder logrus.Errorf("error creating share token: %v", err)
return share.NewShareInternalServerError()
} }
shrZId, frontendEndpoints, responder := h.allocateResources(params, principal, edge, shrToken, trx) shrZId, frontendEndpoints, err := h.allocateResources(params, principal, edge, shrToken, trx)
if responder != nil { if err != nil {
return responder logrus.Errorf("error allocating resources: %v", err)
return share.NewShareInternalServerError()
} }
sshr := h.createShareRecord(shrZId, shrToken, params, frontendEndpoints) sshr := h.createShareRecord(shrZId, shrToken, params, frontendEndpoints)
sid, responder := h.saveShareAndGrants(sshr, envId, accessGrantAcctIds, trx) sid, err := h.saveShareAndGrants(sshr, envId, accessGrantAcctIds, trx)
if responder != nil { if err != nil {
return responder logrus.Errorf("error saving share and grants: %v", err)
return share.NewShareInternalServerError()
} }
if responder := h.handleAuthSecrets(params, sid, sshr, trx); responder != nil { if err := h.handleAuthSecrets(params, sid, sshr, trx); err != nil {
return responder logrus.Errorf("error handling auth secrets: %v", err)
return share.NewShareInternalServerError()
} }
if err := trx.Commit(); err != nil { if err := trx.Commit(); err != nil {
@ -104,21 +110,14 @@ func (h *shareHandler) validateEnvironment(envZId string, principal *rest_model_
return 0, errors.New("environment not found") return 0, errors.New("environment not found")
} }
func (h *shareHandler) handleEnvironmentError(err error) middleware.Responder { func (h *shareHandler) processAccessGrants(params share.ShareParams, principal *rest_model_zrok.Principal, trx *sqlx.Tx) ([]int, error) {
if err.Error() == "environment not found" {
return share.NewShareUnauthorized()
}
return share.NewShareInternalServerError()
}
func (h *shareHandler) processAccessGrants(params share.ShareParams, principal *rest_model_zrok.Principal, trx *sqlx.Tx) ([]int, middleware.Responder) {
var accessGrantAcctIds []int var accessGrantAcctIds []int
if store.PermissionMode(params.Body.PermissionMode) == store.ClosedPermissionMode { if store.PermissionMode(params.Body.PermissionMode) == store.ClosedPermissionMode {
for _, email := range params.Body.AccessGrants { for _, email := range params.Body.AccessGrants {
acct, err := str.FindAccountWithEmail(email, trx) acct, err := str.FindAccountWithEmail(email, trx)
if err != nil { if err != nil {
logrus.Errorf("unable to find account '%v' for share request from '%v'", email, principal.Email) logrus.Errorf("unable to find account '%v' for share request from '%v'", email, principal.Email)
return nil, share.NewShareNotFound() return nil, err
} }
logrus.Debugf("found id '%d' for '%v'", acct.Id, acct.Email) logrus.Debugf("found id '%d' for '%v'", acct.Id, acct.Email)
accessGrantAcctIds = append(accessGrantAcctIds, acct.Id) accessGrantAcctIds = append(accessGrantAcctIds, acct.Id)
@ -127,35 +126,35 @@ func (h *shareHandler) processAccessGrants(params share.ShareParams, principal *
return accessGrantAcctIds, nil return accessGrantAcctIds, nil
} }
func (h *shareHandler) createShareToken(reserved bool, uniqueName string, trx *sqlx.Tx) (string, middleware.Responder) { func (h *shareHandler) createShareToken(reserved bool, uniqueName string, trx *sqlx.Tx) (string, error) {
if !reserved || uniqueName == "" { if !reserved || uniqueName == "" {
token, err := createShareToken() token, err := createShareToken()
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
return "", share.NewShareInternalServerError() return "", err
} }
return token, nil return token, nil
} }
if !util.IsValidUniqueName(uniqueName) { if !util.IsValidUniqueName(uniqueName) {
logrus.Errorf("invalid unique name '%v'", uniqueName) logrus.Errorf("invalid unique name '%v'", uniqueName)
return "", share.NewShareUnprocessableEntity() return "", errors.New("invalid unique name")
} }
shareExists, err := str.ShareWithTokenExists(uniqueName, trx) shareExists, err := str.ShareWithTokenExists(uniqueName, trx)
if err != nil { if err != nil {
logrus.Errorf("error checking share for token collision: %v", err) logrus.Errorf("error checking share for token collision: %v", err)
return "", share.NewUpdateShareInternalServerError() return "", err
} }
if shareExists { if shareExists {
logrus.Errorf("token '%v' already exists; cannot create share", uniqueName) logrus.Errorf("token '%v' already exists; cannot create share", uniqueName)
return "", share.NewShareConflict() return "", errors.New("token already exists")
} }
return uniqueName, nil return uniqueName, nil
} }
func (h *shareHandler) allocateResources(params share.ShareParams, principal *rest_model_zrok.Principal, edge *rest_management_api_client.ZitiEdgeManagement, shrToken string, trx *sqlx.Tx) (string, []string, middleware.Responder) { func (h *shareHandler) allocateResources(params share.ShareParams, principal *rest_model_zrok.Principal, edge *rest_management_api_client.ZitiEdgeManagement, shrToken string, trx *sqlx.Tx) (string, []string, error) {
var shrZId string var shrZId string
var frontendEndpoints []string var frontendEndpoints []string
var err error var err error
@ -167,12 +166,12 @@ func (h *shareHandler) allocateResources(params share.ShareParams, principal *re
shrZId, frontendEndpoints, err = h.allocatePrivateResources(params, edge, shrToken) shrZId, frontendEndpoints, err = h.allocatePrivateResources(params, edge, shrToken)
default: default:
logrus.Errorf("unknown share mode '%v'", params.Body.ShareMode) logrus.Errorf("unknown share mode '%v'", params.Body.ShareMode)
return "", nil, share.NewShareInternalServerError() return "", nil, errors.New("unknown share mode")
} }
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
return "", nil, share.NewShareInternalServerError() return "", nil, err
} }
return shrZId, frontendEndpoints, nil return shrZId, frontendEndpoints, nil
@ -257,11 +256,11 @@ func (h *shareHandler) createShareRecord(shrZId string, shrToken string, params
return sshr return sshr
} }
func (h *shareHandler) saveShareAndGrants(sshr *store.Share, envId int, accessGrantAcctIds []int, trx *sqlx.Tx) (int, middleware.Responder) { func (h *shareHandler) saveShareAndGrants(sshr *store.Share, envId int, accessGrantAcctIds []int, trx *sqlx.Tx) (int, error) {
sid, err := str.CreateShare(envId, sshr, trx) sid, err := str.CreateShare(envId, sshr, trx)
if err != nil { if err != nil {
logrus.Errorf("error creating share record: %v", err) logrus.Errorf("error creating share record: %v", err)
return 0, share.NewShareInternalServerError() return 0, err
} }
if sshr.PermissionMode == store.ClosedPermissionMode { if sshr.PermissionMode == store.ClosedPermissionMode {
@ -269,7 +268,7 @@ func (h *shareHandler) saveShareAndGrants(sshr *store.Share, envId int, accessGr
_, err := str.CreateAccessGrant(sid, acctId, trx) _, err := str.CreateAccessGrant(sid, acctId, trx)
if err != nil { if err != nil {
logrus.Errorf("error creating access grant: %v", err) logrus.Errorf("error creating access grant: %v", err)
return 0, share.NewShareInternalServerError() return 0, err
} }
} }
} }
@ -277,7 +276,7 @@ func (h *shareHandler) saveShareAndGrants(sshr *store.Share, envId int, accessGr
return sid, nil return sid, nil
} }
func (h *shareHandler) handleAuthSecrets(params share.ShareParams, sid int, sshr *store.Share, trx *sqlx.Tx) middleware.Responder { func (h *shareHandler) handleAuthSecrets(params share.ShareParams, sid int, sshr *store.Share, trx *sqlx.Tx) error {
if sshr.ShareMode == string(sdk.PublicShareMode) && params.Body.AuthScheme == string(sdk.Basic) { if sshr.ShareMode == string(sdk.PublicShareMode) && params.Body.AuthScheme == string(sdk.Basic) {
logrus.Infof("writing basic auth secrets for '%v'", sshr.Token) logrus.Infof("writing basic auth secrets for '%v'", sshr.Token)
authUsersMap := make(map[string]string) authUsersMap := make(map[string]string)
@ -287,7 +286,7 @@ func (h *shareHandler) handleAuthSecrets(params share.ShareParams, sid int, sshr
authUsersMapJson, err := json.Marshal(authUsersMap) authUsersMapJson, err := json.Marshal(authUsersMap)
if err != nil { if err != nil {
logrus.Errorf("error marshalling auth secrets for '%v': %v", sshr.Token, err) logrus.Errorf("error marshalling auth secrets for '%v': %v", sshr.Token, err)
return share.NewShareInternalServerError() return err
} }
secrets := store.Secrets{ secrets := store.Secrets{
ShareId: sid, ShareId: sid,
@ -298,7 +297,7 @@ func (h *shareHandler) handleAuthSecrets(params share.ShareParams, sid int, sshr
} }
if err := str.CreateSecrets(secrets, trx); err != nil { if err := str.CreateSecrets(secrets, trx); err != nil {
logrus.Errorf("error creating secrets: %v", err) logrus.Errorf("error creating secrets: %v", err)
return share.NewShareInternalServerError() return err
} }
logrus.Infof("wrote auth secrets for '%v'", sshr.Token) logrus.Infof("wrote auth secrets for '%v'", sshr.Token)
} }

View File

@ -53,6 +53,12 @@ func (o *ShareReader) ReadResponse(response runtime.ClientResponse, consumer run
return nil, err return nil, err
} }
return nil, result return nil, result
case 429:
result := NewShareTooManyRequests()
if err := result.readResponse(response, consumer, o.formats); err != nil {
return nil, err
}
return nil, result
case 500: case 500:
result := NewShareInternalServerError() result := NewShareInternalServerError()
if err := result.readResponse(response, consumer, o.formats); err != nil { if err := result.readResponse(response, consumer, o.formats); err != nil {
@ -356,6 +362,62 @@ func (o *ShareUnprocessableEntity) readResponse(response runtime.ClientResponse,
return nil return nil
} }
// NewShareTooManyRequests creates a ShareTooManyRequests with default headers values
func NewShareTooManyRequests() *ShareTooManyRequests {
return &ShareTooManyRequests{}
}
/*
ShareTooManyRequests describes a response with status code 429, with default header values.
over limit
*/
type ShareTooManyRequests struct {
}
// IsSuccess returns true when this share too many requests response has a 2xx status code
func (o *ShareTooManyRequests) IsSuccess() bool {
return false
}
// IsRedirect returns true when this share too many requests response has a 3xx status code
func (o *ShareTooManyRequests) IsRedirect() bool {
return false
}
// IsClientError returns true when this share too many requests response has a 4xx status code
func (o *ShareTooManyRequests) IsClientError() bool {
return true
}
// IsServerError returns true when this share too many requests response has a 5xx status code
func (o *ShareTooManyRequests) IsServerError() bool {
return false
}
// IsCode returns true when this share too many requests response a status code equal to that given
func (o *ShareTooManyRequests) IsCode(code int) bool {
return code == 429
}
// Code gets the status code for the share too many requests response
func (o *ShareTooManyRequests) Code() int {
return 429
}
func (o *ShareTooManyRequests) Error() string {
return fmt.Sprintf("[POST /share][%d] shareTooManyRequests ", 429)
}
func (o *ShareTooManyRequests) String() string {
return fmt.Sprintf("[POST /share][%d] shareTooManyRequests ", 429)
}
func (o *ShareTooManyRequests) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
return nil
}
// NewShareInternalServerError creates a ShareInternalServerError with default headers values // NewShareInternalServerError creates a ShareInternalServerError with default headers values
func NewShareInternalServerError() *ShareInternalServerError { func NewShareInternalServerError() *ShareInternalServerError {
return &ShareInternalServerError{} return &ShareInternalServerError{}

View File

@ -2298,6 +2298,9 @@ func init() {
"422": { "422": {
"description": "unprocessable" "description": "unprocessable"
}, },
"429": {
"description": "over limit"
},
"500": { "500": {
"description": "internal server error", "description": "internal server error",
"schema": { "schema": {
@ -5194,6 +5197,9 @@ func init() {
"422": { "422": {
"description": "unprocessable" "description": "unprocessable"
}, },
"429": {
"description": "over limit"
},
"500": { "500": {
"description": "internal server error", "description": "internal server error",
"schema": { "schema": {

View File

@ -158,6 +158,31 @@ func (o *ShareUnprocessableEntity) WriteResponse(rw http.ResponseWriter, produce
rw.WriteHeader(422) rw.WriteHeader(422)
} }
// ShareTooManyRequestsCode is the HTTP code returned for type ShareTooManyRequests
const ShareTooManyRequestsCode int = 429
/*
ShareTooManyRequests over limit
swagger:response shareTooManyRequests
*/
type ShareTooManyRequests struct {
}
// NewShareTooManyRequests creates ShareTooManyRequests with default headers values
func NewShareTooManyRequests() *ShareTooManyRequests {
return &ShareTooManyRequests{}
}
// WriteResponse to the client
func (o *ShareTooManyRequests) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
rw.Header().Del(runtime.HeaderContentType) //Remove Content-Type on empty responses
rw.WriteHeader(429)
}
// ShareInternalServerErrorCode is the HTTP code returned for type ShareInternalServerError // ShareInternalServerErrorCode is the HTTP code returned for type ShareInternalServerError
const ShareInternalServerErrorCode int = 500 const ShareInternalServerErrorCode int = 500

View File

@ -166,6 +166,7 @@ Name | Type | Description | Notes
**404** | not found | - | **404** | not found | - |
**409** | conflict | - | **409** | conflict | - |
**422** | unprocessable | - | **422** | unprocessable | - |
**429** | over limit | - |
**500** | internal server error | - | **500** | internal server error | - |
[[Back to top]](#) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to Model list]](../README.md#documentation-for-models) [[Back to README]](../README.md) [[Back to top]](#) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to Model list]](../README.md#documentation-for-models) [[Back to README]](../README.md)

View File

@ -382,6 +382,7 @@ class ShareApi:
'404': None, '404': None,
'409': None, '409': None,
'422': None, '422': None,
'429': None,
'500': "str", '500': "str",
} }
response_data = self.api_client.call_api( response_data = self.api_client.call_api(
@ -453,6 +454,7 @@ class ShareApi:
'404': None, '404': None,
'409': None, '409': None,
'422': None, '422': None,
'429': None,
'500': "str", '500': "str",
} }
response_data = self.api_client.call_api( response_data = self.api_client.call_api(
@ -524,6 +526,7 @@ class ShareApi:
'404': None, '404': None,
'409': None, '409': None,
'422': None, '422': None,
'429': None,
'500': "str", '500': "str",
} }
response_data = self.api_client.call_api( response_data = self.api_client.call_api(

View File

@ -1497,6 +1497,8 @@ paths:
description: conflict description: conflict
422: 422:
description: unprocessable description: unprocessable
429:
description: over limit
500: 500:
description: internal server error description: internal server error
schema: schema: