mirror of
https://github.com/openziti/zrok.git
synced 2024-12-22 14:50:55 +01:00
roughed-in access handler (#111)
This commit is contained in:
parent
f47d97d103
commit
09c603845c
75
controller/access.go
Normal file
75
controller/access.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-openapi/runtime/middleware"
|
||||||
|
"github.com/openziti-test-kitchen/zrok/controller/store"
|
||||||
|
"github.com/openziti-test-kitchen/zrok/rest_model_zrok"
|
||||||
|
"github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/service"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accessHandler struct{}
|
||||||
|
|
||||||
|
func newAccessHandler() *accessHandler {
|
||||||
|
return &accessHandler{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *accessHandler) Handle(params service.AccessParams, principal *rest_model_zrok.Principal) middleware.Responder {
|
||||||
|
tx, err := str.Begin()
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorf("error starting transaction: %v", err)
|
||||||
|
return service.NewAccessInternalServerError()
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
envZId := params.Body.ZID
|
||||||
|
envId := 0
|
||||||
|
if envs, err := str.FindEnvironmentsForAccount(int(principal.ID), tx); err == nil {
|
||||||
|
found := false
|
||||||
|
for _, env := range envs {
|
||||||
|
if env.ZId == envZId {
|
||||||
|
logrus.Debugf("found identity '%v' for user '%v'", envZId, principal.Email)
|
||||||
|
envId = env.Id
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
logrus.Errorf("environment '%v' not found for user '%v'", envZId, principal.Email)
|
||||||
|
return service.NewAccessUnauthorized()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logrus.Errorf("error finding environments for account '%v'", principal.Email)
|
||||||
|
return service.NewAccessInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
ssvcs, err := str.FindServicesForEnvironment(envId, tx)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorf("error finding services for environment")
|
||||||
|
return service.NewAccessInternalServerError()
|
||||||
|
}
|
||||||
|
var ssvc *store.Service
|
||||||
|
for _, v := range ssvcs {
|
||||||
|
if v.Name == params.Body.SvcName {
|
||||||
|
ssvc = v
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ssvc == nil {
|
||||||
|
logrus.Errorf("unable to find service '%v' for user '%v'", params.Body.SvcName, principal.Email)
|
||||||
|
return service.NewAccessNotFound()
|
||||||
|
}
|
||||||
|
|
||||||
|
edge, err := edgeClient()
|
||||||
|
if err != nil {
|
||||||
|
logrus.Error(err)
|
||||||
|
return service.NewAccessInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := createServicePolicyDial(envZId, ssvc.Name, ssvc.ZId, edge); err != nil {
|
||||||
|
logrus.Errorf("unable to create dial policy: %v", err)
|
||||||
|
return service.NewAccessInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return service.NewAccessCreated()
|
||||||
|
}
|
@ -34,6 +34,7 @@ func Run(inCfg *Config) error {
|
|||||||
api.IdentityVerifyHandler = newVerifyHandler()
|
api.IdentityVerifyHandler = newVerifyHandler()
|
||||||
api.MetadataOverviewHandler = metadata.OverviewHandlerFunc(overviewHandler)
|
api.MetadataOverviewHandler = metadata.OverviewHandlerFunc(overviewHandler)
|
||||||
api.MetadataVersionHandler = metadata.VersionHandlerFunc(versionHandler)
|
api.MetadataVersionHandler = metadata.VersionHandlerFunc(versionHandler)
|
||||||
|
api.ServiceAccessHandler = newAccessHandler()
|
||||||
api.ServiceShareHandler = newShareHandler()
|
api.ServiceShareHandler = newShareHandler()
|
||||||
api.ServiceUnshareHandler = newUnshareHandler()
|
api.ServiceUnshareHandler = newUnshareHandler()
|
||||||
|
|
||||||
|
@ -32,6 +32,12 @@ func (o *AccessReader) ReadResponse(response runtime.ClientResponse, consumer ru
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil, result
|
return nil, result
|
||||||
|
case 404:
|
||||||
|
result := NewAccessNotFound()
|
||||||
|
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nil, result
|
||||||
case 500:
|
case 500:
|
||||||
result := NewAccessInternalServerError()
|
result := NewAccessInternalServerError()
|
||||||
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
||||||
@ -145,6 +151,57 @@ func (o *AccessUnauthorized) readResponse(response runtime.ClientResponse, consu
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewAccessNotFound creates a AccessNotFound with default headers values
|
||||||
|
func NewAccessNotFound() *AccessNotFound {
|
||||||
|
return &AccessNotFound{}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
AccessNotFound describes a response with status code 404, with default header values.
|
||||||
|
|
||||||
|
not found
|
||||||
|
*/
|
||||||
|
type AccessNotFound struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSuccess returns true when this access not found response has a 2xx status code
|
||||||
|
func (o *AccessNotFound) IsSuccess() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRedirect returns true when this access not found response has a 3xx status code
|
||||||
|
func (o *AccessNotFound) IsRedirect() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClientError returns true when this access not found response has a 4xx status code
|
||||||
|
func (o *AccessNotFound) IsClientError() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServerError returns true when this access not found response has a 5xx status code
|
||||||
|
func (o *AccessNotFound) IsServerError() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCode returns true when this access not found response a status code equal to that given
|
||||||
|
func (o *AccessNotFound) IsCode(code int) bool {
|
||||||
|
return code == 404
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *AccessNotFound) Error() string {
|
||||||
|
return fmt.Sprintf("[POST /access][%d] accessNotFound ", 404)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *AccessNotFound) String() string {
|
||||||
|
return fmt.Sprintf("[POST /access][%d] accessNotFound ", 404)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *AccessNotFound) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewAccessInternalServerError creates a AccessInternalServerError with default headers values
|
// NewAccessInternalServerError creates a AccessInternalServerError with default headers values
|
||||||
func NewAccessInternalServerError() *AccessInternalServerError {
|
func NewAccessInternalServerError() *AccessInternalServerError {
|
||||||
return &AccessInternalServerError{}
|
return &AccessInternalServerError{}
|
||||||
|
@ -62,6 +62,9 @@ func init() {
|
|||||||
"401": {
|
"401": {
|
||||||
"description": "unauthorized"
|
"description": "unauthorized"
|
||||||
},
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "not found"
|
||||||
|
},
|
||||||
"500": {
|
"500": {
|
||||||
"description": "internal server error"
|
"description": "internal server error"
|
||||||
}
|
}
|
||||||
@ -779,6 +782,9 @@ func init() {
|
|||||||
"401": {
|
"401": {
|
||||||
"description": "unauthorized"
|
"description": "unauthorized"
|
||||||
},
|
},
|
||||||
|
"404": {
|
||||||
|
"description": "not found"
|
||||||
|
},
|
||||||
"500": {
|
"500": {
|
||||||
"description": "internal server error"
|
"description": "internal server error"
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,31 @@ func (o *AccessUnauthorized) WriteResponse(rw http.ResponseWriter, producer runt
|
|||||||
rw.WriteHeader(401)
|
rw.WriteHeader(401)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccessNotFoundCode is the HTTP code returned for type AccessNotFound
|
||||||
|
const AccessNotFoundCode int = 404
|
||||||
|
|
||||||
|
/*
|
||||||
|
AccessNotFound not found
|
||||||
|
|
||||||
|
swagger:response accessNotFound
|
||||||
|
*/
|
||||||
|
type AccessNotFound struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccessNotFound creates AccessNotFound with default headers values
|
||||||
|
func NewAccessNotFound() *AccessNotFound {
|
||||||
|
|
||||||
|
return &AccessNotFound{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteResponse to the client
|
||||||
|
func (o *AccessNotFound) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
|
||||||
|
|
||||||
|
rw.Header().Del(runtime.HeaderContentType) //Remove Content-Type on empty responses
|
||||||
|
|
||||||
|
rw.WriteHeader(404)
|
||||||
|
}
|
||||||
|
|
||||||
// AccessInternalServerErrorCode is the HTTP code returned for type AccessInternalServerError
|
// AccessInternalServerErrorCode is the HTTP code returned for type AccessInternalServerError
|
||||||
const AccessInternalServerErrorCode int = 500
|
const AccessInternalServerErrorCode int = 500
|
||||||
|
|
||||||
|
@ -196,6 +196,8 @@ paths:
|
|||||||
description: access created
|
description: access created
|
||||||
401:
|
401:
|
||||||
description: unauthorized
|
description: unauthorized
|
||||||
|
404:
|
||||||
|
description: not found
|
||||||
500:
|
500:
|
||||||
description: internal server error
|
description: internal server error
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user